/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.SocialRecommender;

@ModelData(value={"isRating", "rste", "userFactors", "itemFactors", "userSocialRatio", "socialMatrix"})
public class RSTERecommender
extends SocialRecommender {
    private float userSocialRatio;

    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userFactors.init(1.0);
        this.itemFactors.init(1.0);
        this.userSocialRatio = this.conf.getFloat("rec.user.social.ratio", Float.valueOf(0.8f)).floatValue();
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            double sum;
            double predictRating;
            this.loss = 0.0;
            DenseMatrix tempUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix tempItemFactors = new DenseMatrix(this.numItems, this.numFactors);
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                SparseVector userSoicalValues = this.socialMatrix.row(userIdx);
                int[] userSocialIndice = userSoicalValues.getIndex();
                double weightSocialSum = 0.0;
                for (int userSoicalIdx : userSocialIndice) {
                    weightSocialSum += userSoicalValues.get(userSoicalIdx);
                }
                double[] sumUserSocialFactor = new double[this.numFactors];
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    int userSoicalIdx;
                    int[] nArray = userSocialIndice;
                    userSoicalIdx = nArray.length;
                    for (int i = 0; i < userSoicalIdx; ++i) {
                        int userSoicalIdx2 = nArray[i];
                        int n = factorIdx;
                        sumUserSocialFactor[n] = sumUserSocialFactor[n] + userSoicalValues.get(userSoicalIdx2) * this.userFactors.get(userSoicalIdx2, factorIdx);
                    }
                }
                for (VectorEntry vectorEntry : this.trainMatrix.row(userIdx)) {
                    int itemIdx = vectorEntry.index();
                    double rating = vectorEntry.get();
                    double norRating = Maths.normalize(rating, this.minRate, this.maxRate);
                    predictRating = DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
                    sum = 0.0;
                    for (int k : userSocialIndice) {
                        sum += userSoicalValues.get(k) * DenseMatrix.rowMult(this.userFactors, k, this.itemFactors, itemIdx);
                    }
                    double socialPredictRating = weightSocialSum > 0.0 ? sum / weightSocialSum : 0.0;
                    double finalPredictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * socialPredictRating;
                    double error = Maths.logistic(finalPredictRating) - norRating;
                    this.loss += error * error;
                    double deriValue = Maths.logisticGradientValue(finalPredictRating) * error;
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                        double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                        double userDeriValue = (double)this.userSocialRatio * deriValue * itemFactorValue + (double)this.regUser * userFactorValue;
                        double userSocialFactorValue = weightSocialSum > 0.0 ? sumUserSocialFactor[factorIdx] / weightSocialSum : 0.0;
                        double itemDeriValue = deriValue * ((double)this.userSocialRatio * userFactorValue + (double)(1.0f - this.userSocialRatio) * userSocialFactorValue) + (double)this.regItem * itemFactorValue;
                        tempUserFactors.add(userIdx, factorIdx, userDeriValue);
                        tempItemFactors.add(itemIdx, factorIdx, itemDeriValue);
                        this.loss += (double)this.regUser * userFactorValue * userFactorValue + (double)this.regItem * itemFactorValue * itemFactorValue;
                    }
                }
            }
            for (int userSocialIdx = 0; userSocialIdx < this.numUsers; ++userSocialIdx) {
                SparseVector socialUserValues = this.socialMatrix.column(userSocialIdx);
                for (int socialUserIdx : socialUserValues.getIndex()) {
                    if (socialUserIdx >= this.numUsers) continue;
                    SparseVector socialItemValues = this.trainMatrix.row(socialUserIdx);
                    SparseVector socialUserSoicalValues = this.socialMatrix.row(socialUserIdx);
                    int[] socialUserSocialIndices = socialUserSoicalValues.getIndex();
                    for (int socialItemIdx : socialItemValues.getIndex()) {
                        predictRating = DenseMatrix.rowMult(this.userFactors, socialUserIdx, this.itemFactors, socialItemIdx);
                        sum = 0.0;
                        double socialWeightSum = 0.0;
                        for (int socialUserSocialIdx : socialUserSocialIndices) {
                            double socialUserSocialValue = socialUserSoicalValues.get(socialUserSocialIdx);
                            sum += socialUserSocialValue * DenseMatrix.rowMult(this.userFactors, socialUserSocialIdx, this.itemFactors, socialItemIdx);
                            socialWeightSum += socialUserSocialValue;
                        }
                        double socialPredictRating = socialWeightSum > 0.0 ? sum / socialWeightSum : 0.0;
                        double finalPredictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * socialPredictRating;
                        double error = Maths.logistic(finalPredictRating) - Maths.normalize(socialItemValues.get(socialItemIdx), this.minRate, this.maxRate);
                        double deriValue = Maths.logisticGradientValue(finalPredictRating) * error * socialUserValues.get(socialUserIdx);
                        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                            tempUserFactors.add(userSocialIdx, factorIdx, (double)(1.0f - this.userSocialRatio) * deriValue * this.itemFactors.get(socialItemIdx, factorIdx));
                        }
                    }
                }
            }
            this.userFactors = this.userFactors.add(tempUserFactors.scale(-this.learnRate));
            this.itemFactors = this.itemFactors.add(tempItemFactors.scale(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) {
        double predictRating = DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
        double sum = 0.0;
        double socialWeightSum = 0.0;
        SparseVector userSocialVector = this.socialMatrix.row(userIdx);
        for (int userSoicalIdx : userSocialVector.getIndex()) {
            double userSocialValue = userSocialVector.get(userSoicalIdx);
            sum += userSocialValue * DenseMatrix.rowMult(this.userFactors, userSoicalIdx, this.itemFactors, itemIdx);
            socialWeightSum += userSocialValue;
        }
        double soicalPredictRatting = socialWeightSum > 0.0 ? sum / socialWeightSum : 0.0;
        predictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * soicalPredictRatting;
        return predictRating;
    }
}

