/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import java.util.ArrayList;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.SocialRecommender;

/**
 * Rating-prediction recommender that jointly factorizes the user-item rating matrix and the
 * user-user social matrix, sharing the user latent factors between the two reconstructions.
 *
 * <p>The social term follows the SoRec formulation (Ma et al., CIKM 2008): each positive social
 * link {@code (u, k)} is re-weighted by {@code sqrt(inDegree(k) / (outDegree(u) + inDegree(k)))},
 * and both reconstructions pass the factor dot product through {@link Maths#logistic}.
 *
 * <p>Training is full-batch gradient descent. In each iteration, gradients for every training
 * rating and every positive social link are first accumulated into scratch matrices. They are
 * then applied to the factor matrices in a single step scaled by {@code -learnRate}.
 */
@ModelData(value={"isRating", "sorec", "userFactors", "itemFactors"})
public class SoRecRecommender
extends SocialRecommender {
    /** Latent factors of users in their role as the target of a social link (Z in SoRec). */
    private DenseMatrix userSocialFactors;
    /** Weight of the social-reconstruction error term ({@code rec.rate.social.regularization}). */
    private float regRateSocial;
    /** L2 regularization weight for {@link #userSocialFactors} ({@code rec.user.social.regularization}). */
    private float regUserSocial;
    /** Per-user in-degree, as reported by {@code socialMatrix.columnSize(user)}. */
    private int[] inDegrees;
    /** Per-user out-degree, as reported by {@code socialMatrix.rowSize(user)}. */
    private int[] outDegrees;

    /**
     * Prepares the model for training.
     *
     * <p>Re-initializes the user and item factors and creates the social factor matrix, all via
     * {@code DenseMatrix.init(1.0)}. Reads the two social regularization weights (default 0.01)
     * and caches each user's in/out degree in the social matrix.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        // Keep this exact init order (user, item, social) so random initialization stays
        // reproducible against earlier runs.
        userFactors.init(1.0);
        itemFactors.init(1.0);
        regRateSocial = conf.getFloat("rec.rate.social.regularization", 0.01f);
        regUserSocial = conf.getFloat("rec.user.social.regularization", 0.01f);
        userSocialFactors = new DenseMatrix(numUsers, numFactors);
        userSocialFactors.init(1.0);

        // Primitive arrays avoid boxing on every social entry of every training iteration.
        inDegrees = new int[numUsers];
        outDegrees = new int[numUsers];
        for (int user = 0; user < numUsers; user++) {
            inDegrees[user] = socialMatrix.columnSize(user);
            outDegrees[user] = socialMatrix.rowSize(user);
        }
    }

    /**
     * Runs up to {@code numIterations} full-batch gradient-descent iterations.
     *
     * <p>Stops early when {@code isConverged(iter)} reports convergence and {@code earlyStop} is
     * set.
     *
     * @throws LibrecException if prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0;
            DenseMatrix userGradients = new DenseMatrix(numUsers, numFactors);
            DenseMatrix itemGradients = new DenseMatrix(numItems, numFactors);
            DenseMatrix userSocialGradients = new DenseMatrix(numUsers, numFactors);

            accumulateRatingGradients(userGradients, itemGradients);
            accumulateSocialGradients(userGradients, userSocialGradients);

            // One full-batch descent step on all three factor matrices.
            userFactors = userFactors.add(userGradients.scale(-learnRate));
            itemFactors = itemFactors.add(itemGradients.scale(-learnRate));
            userSocialFactors = userSocialFactors.add(userSocialGradients.scale(-learnRate));

            // The accumulated sum is of squared terms; the reported objective is half of it.
            loss *= 0.5;
            // isConverged(iter) is deliberately evaluated before earlyStop so it runs every
            // iteration. NOTE(review): presumably it records loss history — confirm before reordering.
            if (isConverged(iter) && earlyStop) {
                break;
            }
            updateLRate(iter);
        }
    }

    /**
     * Accumulates the rating-reconstruction gradients for every entry of {@code trainMatrix}.
     *
     * <p>Also adds the squared rating error and the per-entry L2 terms to {@code loss}. Ratings
     * are compared against {@code Maths.normalize(rating, minRate, maxRate)}.
     *
     * @param userGradients accumulator for user-factor gradients (numUsers x numFactors)
     * @param itemGradients accumulator for item-factor gradients (numItems x numFactors)
     * @throws LibrecException if {@code predict} fails
     */
    private void accumulateRatingGradients(DenseMatrix userGradients, DenseMatrix itemGradients)
            throws LibrecException {
        for (MatrixEntry entry : trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double rating = entry.get();
            double predictRating = predict(userIdx, itemIdx);
            double error = Maths.logistic(predictRating) - Maths.normalize(rating, minRate, maxRate);
            loss += error * error;

            // Same for every factor: logistic derivative times residual. This has the same
            // grouping as (g' * error) * v, so the results are unchanged.
            double scaledError = Maths.logisticGradientValue(predictRating) * error;
            for (int f = 0; f < numFactors; f++) {
                double userFactorValue = userFactors.get(userIdx, f);
                double itemFactorValue = itemFactors.get(itemIdx, f);
                userGradients.add(userIdx, f,
                        scaledError * itemFactorValue + (double) regUser * userFactorValue);
                itemGradients.add(itemIdx, f,
                        scaledError * userFactorValue + (double) regItem * itemFactorValue);
                loss += (double) regUser * userFactorValue * userFactorValue
                        + (double) regItem * itemFactorValue * itemFactorValue;
            }
        }
    }

    /**
     * Accumulates the social-reconstruction gradients for every positive entry of
     * {@code socialMatrix}.
     *
     * <p>Also adds the weighted social error and the L2 terms on {@link #userSocialFactors} to
     * {@code loss}. Non-positive social values are skipped.
     *
     * @param userGradients accumulator for user-factor gradients (shared with the rating term)
     * @param userSocialGradients accumulator for social-factor gradients (numUsers x numFactors)
     * @throws LibrecException declared for parity with {@link #trainModel()}
     */
    private void accumulateSocialGradients(DenseMatrix userGradients, DenseMatrix userSocialGradients)
            throws LibrecException {
        for (MatrixEntry entry : socialMatrix) {
            int userIdx = entry.row();
            int userSocialIdx = entry.column();
            double socialValue = entry.get();
            if (socialValue <= 0.0) {
                continue;
            }
            double socialPredictRating =
                    DenseMatrix.rowMult(userFactors, userIdx, userSocialFactors, userSocialIdx);

            // SoRec link weight: sqrt(in(k) / (out(u) + in(k))). The denominator is >= 1 here
            // whenever the degrees count this stored entry.
            int targetInDegree = inDegrees[userSocialIdx];
            int sourceOutDegree = outDegrees[userIdx];
            double weight = Math.sqrt((double) targetInDegree / (double) (sourceOutDegree + targetInDegree));
            double socialError = Maths.logistic(socialPredictRating) - weight * socialValue;
            loss += (double) regRateSocial * socialError * socialError;

            // Same for every factor. This has the same grouping as ((reg * g') * err) * x.
            double scaledSocialError =
                    (double) regRateSocial * Maths.logisticGradientValue(socialPredictRating) * socialError;
            for (int f = 0; f < numFactors; f++) {
                double userFactorValue = userFactors.get(userIdx, f);
                double userSocialFactorValue = userSocialFactors.get(userSocialIdx, f);
                userGradients.add(userIdx, f, scaledSocialError * userSocialFactorValue);
                userSocialGradients.add(userSocialIdx, f,
                        scaledSocialError * userFactorValue + (double) regUserSocial * userSocialFactorValue);
                loss += (double) regUserSocial * userSocialFactorValue * userSocialFactorValue;
            }
        }
    }
}

