/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.Iterator;
import java.util.List;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * List-wise matrix-factorization recommender.
 *
 * <p>For every user the model compares two "top-one" distributions over that user's
 * training items:
 * <ul>
 *   <li>the softmax of the ratings divided by {@code maxRate};</li>
 *   <li>the softmax of the logistic-squashed predicted scores {@code logistic(u . v)}.</li>
 * </ul>
 * Training minimises the cross-entropy between the two, plus L2 regularisation on both
 * factor matrices. It uses full-batch gradient descent with an adaptive (doubling /
 * halving) learning rate. Presumably this is the ListRank-MF model of Shi, Larson and
 * Hanjalic (RecSys 2010) — confirm against the project documentation.
 */
public class ListRankMFRecommender extends MatrixFactorizationRecommender {
    /**
     * Per-user sum of {@code exp(rating / maxRate)} over that user's training entries.
     * This is the softmax denominator of the real-rating distribution; it is filled in
     * by {@link #setup()}.
     */
    public DenseVector userExp;

    /**
     * Initialises both factor matrices and precomputes {@link #userExp}.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        // NOTE(review): distribution produced by init(1.0) is defined in DenseMatrix — confirm.
        // trainModel treats DenseMatrix.scale as returning a new matrix rather than scaling
        // in place, so the result must be assigned; otherwise the 0.1 shrink is discarded.
        this.userFactors.init(1.0);
        this.userFactors = this.userFactors.scale(0.1);
        this.itemFactors.init(1.0);
        this.itemFactors = this.itemFactors.scale(0.1);
        this.userExp = new DenseVector(this.numUsers);
        for (MatrixEntry entry : this.trainMatrix) {
            double realRating = entry.get() / this.maxRate;
            this.userExp.add(entry.row(), Math.exp(realRating));
        }
    }

    /**
     * Runs {@code numIterations} rounds of full-batch gradient descent with a
     * backtracking learning rate.
     *
     * <p>Each round does the following:
     * <ol>
     *   <li>doubles {@code learnRate};</li>
     *   <li>computes the cross-entropy gradient at the current factors;</li>
     *   <li>takes a step from those factors;</li>
     *   <li>while the new loss exceeds the previous round's loss, halves
     *       {@code learnRate} and retries the step from the same starting point.</li>
     * </ol>
     *
     * @throws LibrecException declared by the recommender contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        double lastLoss = this.getLoss(this.userFactors, this.itemFactors);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            // descend(...) builds new matrices, so these references remain an intact
            // snapshot of the factors at the start of this round.
            DenseMatrix lastUserFactors = this.userFactors;
            DenseMatrix lastItemFactors = this.itemFactors;
            // Optimistically try a larger step; it is halved below if the loss rises.
            this.learnRate *= 2.0f;

            DenseMatrix userGradient = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGradient = new DenseMatrix(this.numItems, this.numFactors);
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                this.accumulateUserGradient(userIdx, userGradient, itemGradient);
            }

            this.userFactors = descend(lastUserFactors, userGradient,
                    -this.learnRate * this.regUser, -this.learnRate);
            this.itemFactors = descend(lastItemFactors, itemGradient,
                    -this.learnRate * this.regItem, -this.learnRate);
            this.loss = this.getLoss(this.userFactors, this.itemFactors);
            while (this.loss > lastLoss) {
                // Backtrack: shrink the step and retry from the round's starting point.
                this.learnRate /= 2.0f;
                this.userFactors = descend(lastUserFactors, userGradient,
                        -this.learnRate * this.regUser, -this.learnRate);
                this.itemFactors = descend(lastItemFactors, itemGradient,
                        -this.learnRate * this.regItem, -this.learnRate);
                this.loss = this.getLoss(this.userFactors, this.itemFactors);
            }

            String info = " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + (lastLoss - this.loss);
            this.LOG.info(info);
            lastLoss = this.loss;
        }
    }

    /**
     * Computes the training objective for the given factors.
     *
     * <p>The objective is the list-wise cross-entropy summed over users, plus
     * {@code 0.5 * reg * ||factor||^2} for every user and item factor row.
     *
     * @param userFactors user latent factors, one row per user
     * @param itemFactors item latent factors, one row per item
     * @return the regularised loss
     */
    public double getLoss(DenseMatrix userFactors, DenseMatrix itemFactors) {
        double loss = 0.0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            double uexp = this.sumExpLogistic(userFactors, itemFactors, userIdx);
            // Walk the user's row: the same rated items used for uexp and by the gradient.
            for (VectorEntry itemEntry : this.trainMatrix.row(userIdx)) {
                int itemIdx = itemEntry.index();
                double realRating = itemEntry.get() / this.maxRate;
                double predictRating = DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx);
                loss -= Math.exp(realRating) / this.userExp.get(userIdx) * Math.log(Math.exp(Maths.logistic(predictRating)) / uexp);
            }
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                double userFactorValue = userFactors.get(userIdx, factorIdx);
                loss += 0.5 * (double) this.regUser * userFactorValue * userFactorValue;
            }
        }
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                double itemFactorValue = itemFactors.get(itemIdx, factorIdx);
                loss += 0.5 * (double) this.regItem * itemFactorValue * itemFactorValue;
            }
        }
        return loss;
    }

    /**
     * Adds one user's (unregularised) cross-entropy gradient into the accumulators.
     * The gradient is evaluated at the current {@code userFactors}/{@code itemFactors}.
     *
     * @param userIdx      the user whose training row is processed
     * @param userGradient accumulator for user-factor gradients (mutated)
     * @param itemGradient accumulator for item-factor gradients (mutated)
     */
    private void accumulateUserGradient(int userIdx, DenseMatrix userGradient, DenseMatrix itemGradient) {
        double uexp = this.sumExpLogistic(this.userFactors, this.itemFactors, userIdx);
        for (VectorEntry vectorEntry : this.trainMatrix.row(userIdx)) {
            int itemIdx = vectorEntry.index();
            double realRating = vectorEntry.get() / this.maxRate;
            double predictRating = DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
            double normalizedRealRating = Math.exp(realRating) / this.userExp.get(userIdx);
            double normalizedPredictRating = Math.exp(Maths.logistic(predictRating)) / uexp;
            // d(cross-entropy)/d(score): (predicted prob - real prob), chained through the logistic.
            double error = (normalizedPredictRating - normalizedRealRating) * Maths.logisticGradientValue(predictRating);
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                userGradient.add(userIdx, factorIdx, error * itemFactorValue);
                itemGradient.add(itemIdx, factorIdx, error * userFactorValue);
            }
        }
    }

    /**
     * Returns the sum of {@code exp(logistic(u . v))} over the user's training items.
     * This is the softmax denominator of the predicted distribution.
     *
     * @param userFactors user latent factors
     * @param itemFactors item latent factors
     * @param userIdx     the user whose rated items are summed over
     * @return the softmax denominator; 0.0 if the user has no training items
     */
    private double sumExpLogistic(DenseMatrix userFactors, DenseMatrix itemFactors, int userIdx) {
        double sum = 0.0;
        for (int itemIdx : this.trainMatrix.getColumns(userIdx)) {
            sum += Math.exp(Maths.logistic(DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx)));
        }
        return sum;
    }

    /**
     * Computes one gradient-descent step without touching its inputs.
     *
     * <p>Returns {@code factors + regScale * factors + stepScale * gradient}. This relies
     * on {@code add}/{@code scale} returning new matrices, as the original training code
     * assumed.
     *
     * @param factors   the starting factors
     * @param gradient  the accumulated loss gradient
     * @param regScale  multiplier applied to {@code factors} (weight-decay term, already negated)
     * @param stepScale multiplier applied to {@code gradient} (negated learning rate)
     * @return the updated factors
     */
    private static DenseMatrix descend(DenseMatrix factors, DenseMatrix gradient, double regScale, double stepScale) {
        DenseMatrix decayed = factors.add(factors.scale(regScale));
        return decayed.add(gradient.scale(stepScale));
    }
}

