/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.util.Lists;

@ModelData(value={"isRanking", "ranksgd", "userFactors", "itemFactors", "trainMatrix"})
public class RankSGDRecommender
extends MatrixFactorizationRecommender {
    protected List<Map.Entry<Integer, Double>> itemProbs;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        HashMap<Integer, Double> itemProbsMap = new HashMap<Integer, Double>();
        for (int j = 0; j < this.numItems; ++j) {
            int users = this.trainMatrix.columnSize(j);
            double prob = ((double)users + 0.0) / (double)this.numRates;
            if (!(prob > 0.0)) continue;
            itemProbsMap.put(j, prob);
        }
        this.itemProbs = Lists.sortMap(itemProbsMap);
    }

    @Override
    protected void trainModel() throws LibrecException {
        List<Set<Integer>> userItemsSet = this.getUserItemsSet(this.trainMatrix);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int userIdx = matrixEntry.row();
                int posItemIdx = matrixEntry.column();
                double posRating = matrixEntry.get();
                int negItemIdx = -1;
                block2: do {
                    double sum = 0.0;
                    double rand = Randoms.random();
                    for (Map.Entry<Integer, Double> itemProb : this.itemProbs) {
                        int itemIdx = itemProb.getKey();
                        double prob = itemProb.getValue();
                        if (!((sum += prob) >= rand)) continue;
                        negItemIdx = itemIdx;
                        continue block2;
                    }
                } while (userItemsSet.get(userIdx).contains(negItemIdx));
                double negRating = 0.0;
                double posPredictRating = this.predict(userIdx, posItemIdx);
                double negPredictRating = this.predict(userIdx, negItemIdx);
                double error = posPredictRating - negPredictRating - (posRating - negRating);
                this.loss += error * error;
                double sgd = (double)this.learnRate * error;
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                    double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
                    this.userFactors.add(userIdx, factorIdx, -sgd * (posItemFactorValue - negItemFactorValue));
                    this.itemFactors.add(posItemIdx, factorIdx, -sgd * userFactorValue);
                    this.itemFactors.add(negItemIdx, factorIdx, sgd * userFactorValue);
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        ArrayList<Set<Integer>> userItemsSet = new ArrayList<Set<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userItemsSet.add(new HashSet<Integer>(sparseMatrix.getColumns(userIdx)));
        }
        return userItemsSet;
    }
}

