/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.SocialRecommender;

/**
 * Matrix-factorization recommender that fits latent factors to both the
 * rating matrix ({@code trainMatrix}) and the social matrix ({@code socialMatrix}).
 *
 * <p>Two factor sets can be learned:
 * <ul>
 *   <li><b>truster</b> ({@code "Tr"}): ratings are scored with
 *       {@code trusterUserTrusterFactors} x {@code trusterItemFactors};</li>
 *   <li><b>trustee</b> ({@code "Te"}): ratings are scored with
 *       {@code trusteeUserTrusteeFactors} x {@code trusteeItemFactors}.</li>
 * </ul>
 * Any other value of the {@code rec.social.model} setting (default {@code "T"})
 * trains both sets and predicts with their element-wise sums, divided by 4.
 *
 * <p>Each factor set is trained against its own rating prediction, so the
 * error term always matches the gradient that is applied to the factors.
 */
@ModelData(value={"isRating", "trustmf", "trusterUserTrusterFactors", "trusterUserTrusteeFactors", "trusteeUserTrusterFactors", "trusteeUserTrusteeFactors", "model"})
public class TrustMFRecommender
extends SocialRecommender {
    /** Configuration key selecting which factor set(s) to train. */
    private static final String MODEL_KEY = "rec.social.model";
    /** Mode value: train and predict with the truster factors only. */
    private static final String MODEL_TRUSTER = "Tr";
    /** Mode value: train and predict with the trustee factors only. */
    private static final String MODEL_TRUSTEE = "Te";
    /** Default mode value: train both factor sets and combine them in {@link #predict(int, int)}. */
    private static final String MODEL_COMBINED = "T";

    /** Truster model: per-user truster factors (numUsers x numFactors). */
    protected DenseMatrix trusterUserTrusterFactors;
    /** Truster model: per-user trustee factors (numUsers x numFactors). */
    protected DenseMatrix trusterUserTrusteeFactors;
    /** Truster model: per-item factors (numItems x numFactors). */
    protected DenseMatrix trusterItemFactors;
    /** Trustee model: per-user truster factors (numUsers x numFactors). */
    protected DenseMatrix trusteeUserTrusterFactors;
    /** Trustee model: per-user trustee factors (numUsers x numFactors). */
    protected DenseMatrix trusteeUserTrusteeFactors;
    /** Trustee model: per-item factors (numItems x numFactors). */
    protected DenseMatrix trusteeItemFactors;
    /** Selected mode, read from {@code rec.social.model} in {@link #setup()}. */
    protected String model;

    /**
     * Runs the parent setup, reads the mode from the configuration and
     * allocates the factor matrices that mode needs.
     *
     * @throws LibrecException propagated from the parent setup
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.model = this.conf.get(MODEL_KEY, MODEL_COMBINED);
        switch (this.model) {
            case MODEL_TRUSTER: {
                this.initTr();
                break;
            }
            case MODEL_TRUSTEE: {
                this.initTe();
                break;
            }
            default: {
                // Combined mode needs both factor sets.
                this.initTr();
                this.initTe();
            }
        }
    }

    /** Allocates and initializes the three truster-model factor matrices. */
    protected void initTr() {
        this.trusterUserTrusterFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.trusterUserTrusteeFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.trusterItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.trusterUserTrusterFactors.init();
        this.trusterUserTrusteeFactors.init();
        this.trusterItemFactors.init();
    }

    /** Allocates and initializes the three trustee-model factor matrices. */
    protected void initTe() {
        this.trusteeUserTrusterFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.trusteeUserTrusteeFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.trusteeItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.trusteeUserTrusterFactors.init();
        this.trusteeUserTrusteeFactors.init();
        this.trusteeItemFactors.init();
    }

    /**
     * Trains the factor set(s) selected by {@link #model}; in combined mode the
     * truster pass runs to completion before the trustee pass starts.
     *
     * @throws LibrecException propagated from the training passes
     */
    @Override
    protected void trainModel() throws LibrecException {
        switch (this.model) {
            case MODEL_TRUSTER: {
                this.TrusterMF();
                break;
            }
            case MODEL_TRUSTEE: {
                this.TrusteeMF();
                break;
            }
            default: {
                // NOTE(review): within this class the learning rate decayed by
                // updateLRate during the truster pass is not restored before the
                // trustee pass begins - confirm whether that carry-over is intended.
                this.TrusterMF();
                this.TrusteeMF();
            }
        }
    }

    /**
     * Rating score of the truster model alone: inner product of the user's
     * truster factors and the item's truster-model factors.
     */
    private double predictTruster(int userIdx, int itemIdx) {
        return DenseMatrix.rowMult(this.trusterUserTrusterFactors, userIdx, this.trusterItemFactors, itemIdx);
    }

    /**
     * Rating score of the trustee model alone: inner product of the user's
     * trustee factors and the item's trustee-model factors.
     */
    private double predictTrustee(int userIdx, int itemIdx) {
        return DenseMatrix.rowMult(this.trusteeUserTrusteeFactors, userIdx, this.trusteeItemFactors, itemIdx);
    }

    /**
     * Batch gradient descent for the truster model.
     *
     * <p>Every iteration accumulates gradients over all rating entries and all
     * positive social entries, applies them once scaled by {@code -learnRate},
     * then checks convergence and updates the learning rate.
     *
     * @throws LibrecException propagated from the convergence check
     */
    protected void TrusterMF() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            DenseMatrix userTrusterGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix userTrusteeGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGradients = new DenseMatrix(this.numItems, this.numFactors);

            // Rating part: squared error between logistic(score) and the normalized rating.
            for (MatrixEntry ratingEntry : this.trainMatrix) {
                int userIdx = ratingEntry.row();
                int itemIdx = ratingEntry.column();
                double rating = ratingEntry.get();
                // Score with the truster factors only: the gradients below are the
                // derivatives of exactly this inner product, so the blended
                // prediction of combined mode must not be used here.
                double predictRating = this.predictTruster(userIdx, itemIdx);
                double error = Maths.logistic(predictRating) - this.normalize(rating);
                this.loss += error * error;
                double deriValue = Maths.logisticGradientValue(predictRating) * error;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactor = this.trusterUserTrusterFactors.get(userIdx, factorIdx);
                    double itemFactor = this.trusterItemFactors.get(itemIdx, factorIdx);
                    userTrusterGradients.add(userIdx, factorIdx, deriValue * itemFactor + (double)this.regUser * userFactor);
                    itemGradients.add(itemIdx, factorIdx, deriValue * userFactor + (double)this.regItem * itemFactor);
                    // Regularization is added once per observed rating.
                    this.loss += (double)this.regUser * userFactor * userFactor + (double)this.regItem * itemFactor * itemFactor;
                }
            }

            // Social part: row user is scored with truster factors, column user with trustee factors.
            for (MatrixEntry socialEntry : this.socialMatrix) {
                int userIdx = socialEntry.row();
                int userSocialIdx = socialEntry.column();
                double socialValue = socialEntry.get();
                // Skip non-positive entries (also skips NaN, as the original comparison did).
                if (!(socialValue > 0.0)) continue;
                double predictSocialValue = DenseMatrix.rowMult(this.trusterUserTrusterFactors, userIdx, this.trusterUserTrusteeFactors, userSocialIdx);
                double socialError = Maths.logistic(predictSocialValue) - socialValue;
                this.loss += (double)this.regSocial * socialError * socialError;
                double deriValue = Maths.logisticGradientValue(predictSocialValue) * socialError;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double trusterFactor = this.trusterUserTrusterFactors.get(userIdx, factorIdx);
                    double trusteeFactor = this.trusterUserTrusteeFactors.get(userSocialIdx, factorIdx);
                    userTrusterGradients.add(userIdx, factorIdx, (double)this.regSocial * deriValue * trusteeFactor + (double)this.regUser * trusterFactor);
                    userTrusteeGradients.add(userSocialIdx, factorIdx, (double)this.regSocial * deriValue * trusterFactor + (double)this.regUser * trusteeFactor);
                    this.loss += (double)this.regUser * trusterFactor * trusterFactor + (double)this.regUser * trusteeFactor * trusteeFactor;
                }
            }

            // Apply the accumulated gradients in one step.
            this.trusterUserTrusterFactors.addEqual(userTrusterGradients.scale(-this.learnRate));
            this.trusterItemFactors.addEqual(itemGradients.scale(-this.learnRate));
            this.trusterUserTrusteeFactors.addEqual(userTrusteeGradients.scale(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Batch gradient descent for the trustee model; mirrors {@link #TrusterMF()}
     * with the roles of the social-matrix row and column users swapped.
     *
     * @throws LibrecException propagated from the convergence check
     */
    protected void TrusteeMF() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            DenseMatrix userTrusterGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix userTrusteeGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGradients = new DenseMatrix(this.numItems, this.numFactors);

            // Rating part: squared error between logistic(score) and the normalized rating.
            for (MatrixEntry ratingEntry : this.trainMatrix) {
                int userIdx = ratingEntry.row();
                int itemIdx = ratingEntry.column();
                double rating = ratingEntry.get();
                // Score with the trustee factors only, matching the gradients below.
                double predictRating = this.predictTrustee(userIdx, itemIdx);
                double error = Maths.logistic(predictRating) - this.normalize(rating);
                this.loss += error * error;
                double deriValue = Maths.logisticGradientValue(predictRating) * error;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactor = this.trusteeUserTrusteeFactors.get(userIdx, factorIdx);
                    double itemFactor = this.trusteeItemFactors.get(itemIdx, factorIdx);
                    userTrusteeGradients.add(userIdx, factorIdx, deriValue * itemFactor + (double)this.regUser * userFactor);
                    itemGradients.add(itemIdx, factorIdx, deriValue * userFactor + (double)this.regItem * itemFactor);
                    this.loss += (double)this.regUser * userFactor * userFactor + (double)this.regItem * itemFactor * itemFactor;
                }
            }

            // Social part: the row user supplies truster factors, the column user trustee factors.
            for (MatrixEntry socialEntry : this.socialMatrix) {
                int userSocialIdx = socialEntry.row();
                int userIdx = socialEntry.column();
                double socialValue = socialEntry.get();
                // Skip non-positive entries (also skips NaN, as the original comparison did).
                if (!(socialValue > 0.0)) continue;
                double predictSocialValue = DenseMatrix.rowMult(this.trusteeUserTrusterFactors, userSocialIdx, this.trusteeUserTrusteeFactors, userIdx);
                double socialError = Maths.logistic(predictSocialValue) - socialValue;
                this.loss += (double)this.regSocial * socialError * socialError;
                double deriValue = Maths.logisticGradientValue(predictSocialValue) * socialError;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double trusteeFactor = this.trusteeUserTrusteeFactors.get(userIdx, factorIdx);
                    double trusterFactor = this.trusteeUserTrusterFactors.get(userSocialIdx, factorIdx);
                    userTrusteeGradients.add(userIdx, factorIdx, (double)this.regSocial * deriValue * trusterFactor + (double)this.regUser * trusteeFactor);
                    userTrusterGradients.add(userSocialIdx, factorIdx, (double)this.regSocial * deriValue * trusteeFactor + (double)this.regUser * trusterFactor);
                    this.loss += (double)this.regUser * trusteeFactor * trusteeFactor + (double)this.regUser * trusterFactor * trusterFactor;
                }
            }

            // Apply the accumulated gradients in one step.
            this.trusteeUserTrusterFactors.addEqual(userTrusterGradients.scale(-this.learnRate));
            this.trusteeItemFactors.addEqual(itemGradients.scale(-this.learnRate));
            this.trusteeUserTrusteeFactors.addEqual(userTrusteeGradients.scale(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Step-wise learning-rate decay: multiplies the rate by 0.6 at iteration 10,
     * by 0.333 at iteration 30 and by 0.5 at iteration 100, then records the
     * current loss as {@code lastLoss}.
     *
     * @param iter 1-based iteration number that has just finished
     */
    @Override
    protected void updateLRate(int iter) {
        if (iter == 10) {
            this.learnRate = (float)((double)this.learnRate * 0.6);
        } else if (iter == 30) {
            this.learnRate = (float)((double)this.learnRate * 0.333);
        } else if (iter == 100) {
            this.learnRate = (float)((double)this.learnRate * 0.5);
        }
        this.lastLoss = this.loss;
    }

    /**
     * Returns the raw (pre-logistic) rating score for a user/item pair.
     *
     * <p>In {@code "Tr"} / {@code "Te"} mode this is the inner product of the
     * corresponding user and item factor rows. In any other mode the truster
     * and trustee rows are summed for the user and for the item, and the inner
     * product of the two sums is divided by 4.
     *
     * @param userIdx row index of the user in the factor matrices
     * @param itemIdx row index of the item in the factor matrices
     * @return the raw score; no logistic or bounding is applied here
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        switch (this.model) {
            case MODEL_TRUSTER: {
                return this.predictTruster(userIdx, itemIdx);
            }
            case MODEL_TRUSTEE: {
                return this.predictTrustee(userIdx, itemIdx);
            }
            default: {
                DenseVector userVector = this.trusterUserTrusterFactors.row(userIdx).add(this.trusteeUserTrusteeFactors.row(userIdx, false));
                DenseVector itemVector = this.trusterItemFactors.row(itemIdx).add(this.trusteeItemFactors.row(itemIdx, false));
                return userVector.inner(itemVector) / 4.0;
            }
        }
    }
}

