/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.ArrayList;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.cf.rating.BiasedMFRecommender;

@ModelData(value={"isRating", "svdplusplus", "userFactors", "itemFactors", "userBiases", "itemBiases", "impItemFactors", "trainMatrix"})
public class SVDPlusPlusRecommender
extends BiasedMFRecommender {
    protected DenseMatrix impItemFactors;
    protected List<List<Integer>> userItemsList;
    private double regImpItem;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.regImpItem = this.conf.getDouble("rec.impItem.regularization", 0.015);
        this.impItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.impItemFactors.init(this.initMean, this.initStd);
        this.userItemsList = this.getUserItemsList(this.trainMatrix);
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double realRating = matrixEntry.get();
                double predictRating = this.predict(userIdx, itemIdx);
                double error = realRating - predictRating;
                this.loss += error * error;
                List<Integer> items = this.userItemsList.get(userIdx);
                double userBiasValue = this.userBiases.get(userIdx);
                this.userBiases.add(userIdx, (double)this.learnRate * (error - this.regBias * userBiasValue));
                this.loss += this.regBias * userBiasValue * userBiasValue;
                double itemBiasValue = this.itemBiases.get(itemIdx);
                this.itemBiases.add(itemIdx, (double)this.learnRate * (error - this.regBias * itemBiasValue));
                this.loss += this.regBias * itemBiasValue * itemBiasValue;
                DenseVector sumImpItemsFactors = new DenseVector(this.numFactors);
                for (int impItemIdx : items) {
                    sumImpItemsFactors.addEqual(this.impItemFactors.row(impItemIdx, false));
                }
                double impNor = Math.sqrt(items.size());
                if (impNor > 0.0) {
                    sumImpItemsFactors.scaleEqual(1.0 / impNor);
                }
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                    this.userFactors.add(userIdx, factorIdx, (double)this.learnRate * (error * itemFactorValue - (double)this.regUser * userFactorValue));
                    this.itemFactors.add(itemIdx, factorIdx, (double)this.learnRate * (error * (userFactorValue + sumImpItemsFactors.get(factorIdx)) - (double)this.regItem * itemFactorValue));
                    this.loss += (double)this.regUser * userFactorValue * userFactorValue + (double)this.regItem * itemFactorValue * itemFactorValue;
                    for (int impItemIdx : items) {
                        double impItemFactor = this.impItemFactors.get(impItemIdx, factorIdx);
                        this.impItemFactors.add(impItemIdx, factorIdx, (double)this.learnRate * (error * itemFactorValue / impNor - this.regImpItem * impItemFactor));
                        this.loss += this.regImpItem * impItemFactor * impItemFactor;
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double predictRating = this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx) + this.globalMean;
        List<Integer> items = this.userItemsList.get(userIdx);
        DenseVector userImpFactor = new DenseVector(this.numFactors);
        for (int impItemIdx : items) {
            userImpFactor.addEqual(this.impItemFactors.row(impItemIdx, false));
        }
        double impNor = Math.sqrt(items.size());
        if (impNor > 0.0) {
            userImpFactor.scaleEqual(1.0 / impNor);
        }
        userImpFactor.addEqual(this.userFactors.row(userIdx, false));
        return predictRating + userImpFactor.inner(this.itemFactors.row(itemIdx, false));
    }

    private List<List<Integer>> getUserItemsList(SparseMatrix sparseMatrix) {
        ArrayList<List<Integer>> userItemsList = new ArrayList<List<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userItemsList.add(sparseMatrix.getColumns(userIdx));
        }
        return userItemsList;
    }
}

