/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseVector;
import net.librec.recommender.SocialRecommender;

@ModelData(value={"isRating", "socialmf", "userFactors", "itemFactors"})
public class SocialMFRecommender
extends SocialRecommender {
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userFactors.init(1.0);
        this.itemFactors.init(1.0);
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            DenseMatrix tempUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix tempItemFactors = new DenseMatrix(this.numItems, this.numFactors);
            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double rating = matrixEntry.get();
                double predictRating = this.predict(userIdx, itemIdx, false);
                double error = Maths.logistic(predictRating) - this.normalize(rating);
                this.loss += error * error;
                double deriValue = Maths.logisticGradientValue(predictRating) * error;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                    tempUserFactors.add(userIdx, factorIdx, deriValue * itemFactorValue + (double)this.regUser * userFactorValue);
                    tempItemFactors.add(itemIdx, factorIdx, deriValue * userFactorValue + (double)this.regItem * itemFactorValue);
                    this.loss += (double)this.regUser * userFactorValue * userFactorValue + (double)this.regItem * itemFactorValue * itemFactorValue;
                }
            }
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                SparseVector userTrustVector = this.socialMatrix.row(userIdx);
                double trustSum = userTrustVector.sum();
                if (trustSum <= 0.0) continue;
                double[] sumNNs = new double[this.numFactors];
                for (int trustUserIdx : userTrustVector.getIndex()) {
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        int n = factorIdx;
                        sumNNs[n] = sumNNs[n] + this.socialMatrix.get(userIdx, trustUserIdx) * this.userFactors.get(trustUserIdx, factorIdx);
                    }
                }
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double diffValue = this.userFactors.get(userIdx, factorIdx) - sumNNs[factorIdx] / trustSum;
                    tempUserFactors.add(userIdx, factorIdx, (double)this.regSocial * diffValue);
                    this.loss += (double)this.regSocial * diffValue * diffValue;
                }
                SparseVector userTrustedVector = this.socialMatrix.column(userIdx);
                double trustedSum = userTrustedVector.sum();
                for (int trustedUserIdx : userTrustedVector.getIndex()) {
                    double trustedValue = this.socialMatrix.get(trustedUserIdx, userIdx);
                    SparseVector trustedTrustVector = this.socialMatrix.row(trustedUserIdx);
                    double[] sumDiffs = new double[this.numFactors];
                    for (int trustedTrustUserIdx : trustedTrustVector.getIndex()) {
                        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                            int n = factorIdx;
                            sumDiffs[n] = sumDiffs[n] + this.socialMatrix.get(trustedUserIdx, trustedTrustUserIdx) * this.userFactors.get(trustedTrustUserIdx, factorIdx);
                        }
                    }
                    trustSum = trustedTrustVector.sum();
                    if (!(trustSum > 0.0)) continue;
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        tempUserFactors.add(userIdx, factorIdx, (double)(-this.regSocial) * (trustedValue / trustedSum) * (this.userFactors.get(trustedUserIdx, factorIdx) - sumDiffs[factorIdx] / trustSum));
                    }
                }
            }
            this.userFactors = this.userFactors.add(tempUserFactors.scale(-this.learnRate));
            this.itemFactors = this.itemFactors.add(tempItemFactors.scale(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }
}

