/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.recommender.SocialRecommender;

/**
 * Matrix-factorisation rating recommender with an additional social regularisation term.
 *
 * <p>NOTE(review): the class name and the similarity-weighted {@code ||U_u - U_f||^2} penalty suggest
 * this is SoReg (Ma et al., "Recommender Systems with Social Regularization", WSDM 2011) — confirm.
 *
 * <p>Training is full-batch: each iteration sums all update terms into scratch matrices and only then
 * moves {@code userFactors} / {@code itemFactors} by {@code -learnRate} times those sums. Besides the
 * squared rating error and per-rating L2 terms, every user is pulled towards each user in its row of
 * {@code socialMatrix} (out-links) and each user in its column (in-links). The pull is weighted by
 * {@code regSocial} times the rescaled similarity prepared in {@link #setup()}.
 */
@ModelData({"isRating", "soreg", "userFactors", "itemFactors"})
public class SoRegRecommender extends SocialRecommender {

    /**
     * User-user similarities taken from the context's similarity measure and rescaled in
     * {@link #setup()}. Pairs whose similarity is NaN are ignored during training.
     */
    private SymmMatrix userSocialCorrs;

    /**
     * Runs the parent setup, initialises both factor matrices and rescales the similarities.
     *
     * <p>Both factor matrices are initialised with {@code init(1.0)}. NOTE(review): presumably a
     * random initialisation; see {@code DenseMatrix#init(double)}.
     *
     * <p>Every stored similarity {@code s} at (a, b) with {@code a < b} is replaced by
     * {@code (1 + s) / 2}. This presumably maps similarities in [-1, 1] onto [0, 1]. Diagonal entries
     * are left untouched.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userFactors.init(1.0);
        this.itemFactors.init(1.0);

        SymmMatrix corrs = this.context.getSimilarity().getSimilarityMatrix();
        this.userSocialCorrs = corrs;

        final int userCount = this.numUsers;
        for (int a = 0; a < userCount; ++a) {
            for (int b = a + 1; b < userCount; ++b) {
                if (corrs.contains(a, b)) {
                    double raw = corrs.get(a, b);
                    corrs.set(a, b, (1.0 + raw) / 2.0);
                }
            }
        }
    }

    /**
     * Runs up to {@code numIterations} full-batch gradient-descent steps.
     *
     * <p>Per iteration, {@code loss} is recomputed from scratch and halved at the end.
     * {@code isConverged(iter)} is always evaluated before {@code earlyStop} is consulted. The
     * learning rate is updated only when training continues.
     *
     * @throws LibrecException if prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            DenseMatrix userGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGradients = new DenseMatrix(this.numItems, this.numFactors);

            accumulateRatingTerms(userGradients, itemGradients);

            for (int user = 0; user < this.numUsers; ++user) {
                // Out-links (row) before in-links (column), so loss terms are summed in a fixed order.
                accumulateSocialTerms(userGradients, user, this.socialMatrix.row(user));
                accumulateSocialTerms(userGradients, user, this.socialMatrix.column(user));
            }

            // Factors change only after every term above was computed from the previous iterate.
            this.userFactors.addEqual(userGradients.scale(-this.learnRate));
            this.itemFactors.addEqual(itemGradients.scale(-this.learnRate));
            this.loss *= 0.5;

            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /**
     * Adds the rating-error and L2 terms for every training entry to the scratch matrices and to
     * {@code loss}.
     *
     * <p>The update terms are {@code error * V_i + regUser * U_u} for users and
     * {@code error * U_u + regItem * V_i} for items. NOTE(review): these are the gradients only if
     * the two-argument {@code predict} is the dot product {@code U_u . V_i} — confirm in the superclass.
     * The L2 penalty is added once per observed rating, not once per user or item.
     *
     * @param userGradients accumulator sized numUsers x numFactors
     * @param itemGradients accumulator sized numItems x numFactors
     * @throws LibrecException if prediction fails
     */
    private void accumulateRatingTerms(DenseMatrix userGradients, DenseMatrix itemGradients)
            throws LibrecException {
        final double regU = this.regUser;
        final double regI = this.regItem;
        for (MatrixEntry entry : this.trainMatrix) {
            int user = entry.row();
            int item = entry.column();
            double observed = entry.get();
            double error = this.predict(user, item) - observed;
            this.loss += error * error;
            for (int k = 0; k < this.numFactors; ++k) {
                double pu = this.userFactors.get(user, k);
                double qi = this.itemFactors.get(item, k);
                userGradients.add(user, k, error * qi + regU * pu);
                itemGradients.add(item, k, error * pu + regI * qi);
                this.loss += regU * pu * pu + regI * qi * qi;
            }
        }
    }

    /**
     * Pulls {@code user}'s factors towards those of each index in {@code neighbours}.
     *
     * <p>For each neighbour with a non-NaN similarity {@code s}, it adds
     * {@code regSocial * s * (U_user - U_other)} to the user's row of the accumulator and
     * {@code regSocial * s * ||U_user - U_other||^2} to {@code loss}.
     *
     * @param userGradients accumulator sized numUsers x numFactors
     * @param user          the user whose row is updated
     * @param neighbours    a row or column of {@code socialMatrix}; only its stored indices are used
     */
    private void accumulateSocialTerms(DenseMatrix userGradients, int user, SparseVector neighbours) {
        for (int other : neighbours.getIndex()) {
            double similarity = this.userSocialCorrs.get(user, other);
            if (!Double.isNaN(similarity)) {
                double weight = this.regSocial * similarity;
                for (int k = 0; k < this.numFactors; ++k) {
                    double diff = this.userFactors.get(user, k) - this.userFactors.get(other, k);
                    userGradients.add(user, k, weight * diff);
                    this.loss += weight * diff * diff;
                }
            }
        }
    }

    /**
     * Predicts a rating, optionally clamped to {@code [minRate, maxRate]}.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @param bound   whether to clamp the raw prediction
     * @return the raw prediction, or the clamped value when {@code bound} is true; a NaN prediction
     *         passes through unclamped
     * @throws LibrecException if the underlying prediction fails
     */
    @Override
    protected double predict(int userIdx, int itemIdx, boolean bound) throws LibrecException {
        final double raw = this.predict(userIdx, itemIdx);
        if (!bound) {
            return raw;
        }
        if (raw > this.maxRate) {
            return this.maxRate;
        }
        if (raw < this.minRate) {
            return this.minRate;
        }
        return raw;
    }
}

