/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Personalized Ranking (BPR) recommender.
 * <p>
 * Learns {@code userFactors} and {@code itemFactors} by stochastic gradient updates on
 * sampled (user, positive item, negative item) triples. The positive item is observed
 * for the user in {@code trainMatrix`; the negative item is not. Each iteration draws
 * {@code numUsers * 100} triples.
 */
@ModelData(value={"isRanking", "bpr", "userFactors", "itemFactors"})
public class BPRRecommender
extends MatrixFactorizationRecommender {
    /** Per-user set of item indices observed in {@code trainMatrix}; used to reject negatives in O(1). */
    private List<Set<Integer>> userItemsSet;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
    }

    /**
     * Runs up to {@code numIterations} epochs of BPR stochastic gradient updates, stopping
     * early when {@code earlyStop} is set and {@code isConverged(iter)} reports convergence.
     *
     * @throws LibrecException if sampling would run but no user has at least one observed
     *         and at least one unobserved item, so no valid training triple could ever be drawn
     */
    @Override
    protected void trainModel() throws LibrecException {
        this.userItemsSet = this.getUserItemsSet(this.trainMatrix);
        // Cache each user's observed items once instead of rebuilding the list for every sample.
        // Assumes trainMatrix is not modified during training.
        List<List<Integer>> userItemsList = this.getUserItemsList(this.trainMatrix);
        int smax = this.numUsers * 100;
        // Without a user that has both positive and negative items, the rejection sampler
        // below would spin forever; fail fast instead.
        if (this.numIterations > 0 && smax > 0 && !this.hasSampleableUser()) {
            throw new LibrecException("BPR cannot sample training triples: every user has either no observed items or all items observed");
        }
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            // updateLRate may change learnRate between iterations, so read it once per iteration.
            double lRate = this.learnRate;
            double regU = this.regUser;
            double regI = this.regItem;
            for (int sampleCount = 0; sampleCount < smax; ++sampleCount) {
                // Draw a user with at least one observed and at least one unobserved item.
                int userIdx;
                Set<Integer> itemSet;
                do {
                    userIdx = Randoms.uniform(this.numUsers);
                    itemSet = this.userItemsSet.get(userIdx);
                } while (itemSet.isEmpty() || itemSet.size() == this.numItems);

                List<Integer> itemList = userItemsList.get(userIdx);
                int posItemIdx = itemList.get(Randoms.uniform(itemList.size()));

                // Rejection-sample an item the user has not interacted with.
                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (itemSet.contains(negItemIdx));

                double posPredictRating = this.predict(userIdx, posItemIdx);
                double negPredictRating = this.predict(userIdx, negItemIdx);
                double diffValue = posPredictRating - negPredictRating;
                // Stable form of -log(logistic(diff)); the naive form becomes Infinity when diff << 0.
                this.loss += negLogSigmoid(diffValue);
                double deriValue = Maths.logistic(-diffValue);
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                    double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
                    this.userFactors.add(userIdx, factorIdx, lRate * (deriValue * (posItemFactorValue - negItemFactorValue) - regU * userFactorValue));
                    this.itemFactors.add(posItemIdx, factorIdx, lRate * (deriValue * userFactorValue - regI * posItemFactorValue));
                    this.itemFactors.add(negItemIdx, factorIdx, lRate * (deriValue * -userFactorValue - regI * negItemFactorValue));
                    this.loss += regU * userFactorValue * userFactorValue + regI * posItemFactorValue * posItemFactorValue + regI * negItemFactorValue * negItemFactorValue;
                }
            }
            // isConverged is evaluated every iteration (before earlyStop) to preserve its call pattern.
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Returns {@code -log(1 / (1 + e^-x))}, i.e. {@code log(1 + e^-x)}, computed so that
     * large-magnitude inputs do not overflow or underflow to infinity.
     */
    private static double negLogSigmoid(double x) {
        return x >= 0.0 ? Math.log1p(Math.exp(-x)) : -x + Math.log1p(Math.exp(x));
    }

    /** Returns true if some user has at least one observed item and is not observed on every item. */
    private boolean hasSampleableUser() {
        for (Set<Integer> items : this.userItemsSet) {
            if (!items.isEmpty() && items.size() != this.numItems) {
                return true;
            }
        }
        return false;
    }

    /** Builds, for each user index, the list of column (item) indices returned by {@code getColumns}. */
    private List<List<Integer>> getUserItemsList(SparseMatrix sparseMatrix) {
        List<List<Integer>> userItemsList = new ArrayList<List<Integer>>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userItemsList.add(sparseMatrix.getColumns(userIdx));
        }
        return userItemsList;
    }

    /** Builds, for each user index, a hash set of the column (item) indices returned by {@code getColumns}. */
    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        ArrayList<Set<Integer>> userItemsSet = new ArrayList<Set<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userItemsSet.add(new HashSet<Integer>(sparseMatrix.getColumns(userIdx)));
        }
        return userItemsSet;
    }
}

