/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.ext;

import java.util.HashMap;
import java.util.Map;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.cf.ranking.RankSGDRecommender;
import net.librec.util.Lists;

@ModelData(value={"isRanking", "prankd", "userFactors", "itemFactors", "trainMatrix"})
public class PRankDRecommender
extends RankSGDRecommender {
    private DenseVector itemWeights;
    private SymmMatrix itemCorrs;
    private float simFilter;

    @Override
    protected void setup() throws LibrecException {
        int itemIdx;
        super.setup();
        this.simFilter = this.conf.getFloat("rec.sim.filter", Float.valueOf(4.0f)).floatValue();
        HashMap<Integer, Double> itemProbsMap = new HashMap<Integer, Double>();
        double maxUsersCount = 0.0;
        this.itemWeights = new DenseVector(this.numItems);
        for (itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            int usersCount = this.trainMatrix.columnSize(itemIdx);
            maxUsersCount = maxUsersCount < (double)usersCount ? (double)usersCount : maxUsersCount;
            this.itemWeights.set(itemIdx, usersCount);
            double prob = ((double)usersCount + 0.0) / (double)this.numRates;
            if (!(prob > 0.0)) continue;
            itemProbsMap.put(itemIdx, prob);
        }
        this.itemProbs = Lists.sortMap(itemProbsMap);
        for (itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            this.itemWeights.set(itemIdx, this.itemWeights.get(itemIdx) / maxUsersCount);
        }
        this.itemCorrs = this.context.getSimilarity().getSimilarityMatrix();
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (int userIdx : this.trainMatrix.rows()) {
                SparseVector itemRatingsVector = this.trainMatrix.row(userIdx);
                for (VectorEntry itemRatingEntry : itemRatingsVector) {
                    int posItemIdx = itemRatingEntry.index();
                    double posRating = itemRatingEntry.get();
                    int negItemIdx = -1;
                    block3: do {
                        double sum = 0.0;
                        double randValue = Randoms.random();
                        for (Map.Entry mapEntry : this.itemProbs) {
                            int tempNegItemIdx = (Integer)mapEntry.getKey();
                            double prob = (Double)mapEntry.getValue();
                            if (!((sum += prob) >= randValue)) continue;
                            negItemIdx = tempNegItemIdx;
                            continue block3;
                        }
                    } while (itemRatingsVector.contains(negItemIdx));
                    double negRating = 0.0;
                    double posPredictRating = this.predict(userIdx, posItemIdx);
                    double negPredictRating = this.predict(userIdx, negItemIdx);
                    double distance = Math.sqrt(1.0 - Math.tanh(this.itemCorrs.get(posItemIdx, negItemIdx) * (double)this.simFilter));
                    double itemWeightValue = this.itemWeights.get(negItemIdx);
                    double error = itemWeightValue * (posPredictRating - negPredictRating - distance * (posRating - negRating));
                    this.loss += error * error;
                    double learnFactor = (double)this.learnRate * error;
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                        double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                        double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
                        this.userFactors.add(userIdx, factorIdx, -learnFactor * (posItemFactorValue - negItemFactorValue));
                        this.itemFactors.add(posItemIdx, factorIdx, -learnFactor * userFactorValue);
                        this.itemFactors.add(negItemIdx, factorIdx, learnFactor * userFactorValue);
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }
}

