/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.ranking;

import com.google.common.cache.LoadingCache;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.SocialRecommender;

@ModelData(value={"isRanking", "sbpr", "userFactors", "itemFactors", "itemBiases"})
public class SBPRRecommender
extends SocialRecommender {
    /** Per-item bias added to every prediction. */
    private DenseVector itemBiases;
    /** Regularization weight for item biases (config key {@code rec.bias.regularization}, default 0.01). */
    protected float regBias;
    /** Cache from a training-matrix row (user) index to the column (item) indices present in that row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Guava cache spec used to build {@link #userItemsCache} (config key {@code guava.cache.spec}). */
    protected static String cacheSpec;
    /**
     * For each user index, the items rated by the users they trust that the user has not rated
     * themselves, without duplicates and in first-seen order. Empty for users with no such items.
     */
    private List<List<Integer>> userSocialItemsSetList;

    /**
     * Reads hyper-parameters, initializes item biases, builds the rated-items cache and
     * precomputes each user's social item list.
     *
     * @throws LibrecException if the superclass setup fails or a cache lookup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.regBias = this.conf.getFloat("rec.bias.regularization", Float.valueOf(0.01f)).floatValue();
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=5000,expireAfterAccess=50m");
        this.itemBiases = new DenseVector(this.numItems);
        this.itemBiases.init();
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);

        int numTrainUsers = this.trainMatrix.numRows();
        this.userSocialItemsSetList = new ArrayList<List<Integer>>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            this.userSocialItemsSetList.add(collectSocialItems(userIdx, numTrainUsers));
        }
    }

    /**
     * Collects the items rated by {@code userIdx}'s trusted users but not by the user.
     *
     * @param userIdx       the user whose social items are collected
     * @param numTrainUsers number of rows in the training matrix; trusted users at or beyond
     *                      this index have no training row and are skipped
     * @return a new mutable list of distinct item indices, in first-seen order; empty when the
     *         user has no training row or no rated items
     * @throws LibrecException if a cache lookup fails
     */
    private List<Integer> collectSocialItems(int userIdx, int numTrainUsers) throws LibrecException {
        if (userIdx >= numTrainUsers) {
            return new ArrayList<Integer>();
        }
        List<Integer> ratedItems = loadRatedItems(userIdx);
        if (ratedItems.isEmpty()) {
            return new ArrayList<Integer>();
        }
        Set<Integer> ratedItemSet = new HashSet<Integer>(ratedItems);
        // LinkedHashSet removes duplicates in O(1) while keeping first-seen order.
        Set<Integer> socialItems = new LinkedHashSet<Integer>();
        for (int trustedUserIdx : this.socialMatrix.getColumns(userIdx)) {
            // Same bound check as the social-weight computation in training.
            if (trustedUserIdx >= numTrainUsers) {
                continue;
            }
            for (int itemIdx : loadRatedItems(trustedUserIdx)) {
                if (!ratedItemSet.contains(itemIdx)) {
                    socialItems.add(itemIdx);
                }
            }
        }
        return new ArrayList<Integer>(socialItems);
    }

    /**
     * Looks up the items a user rated in the training matrix.
     *
     * @param userIdx training-matrix row index
     * @return the cached list of item indices for that row
     * @throws LibrecException wrapping the cache's {@link ExecutionException}
     */
    private List<Integer> loadRatedItems(int userIdx) throws LibrecException {
        try {
            return this.userItemsCache.get(userIdx);
        } catch (ExecutionException e) {
            throw new LibrecException("Failed to load rated items of user " + userIdx, e);
        }
    }

    /**
     * Runs stochastic gradient steps over sampled (user, positive, [social,] negative) tuples.
     *
     * <p>Each iteration draws {@code numUsers * 100} samples. For each sample:
     * <ul>
     *   <li>a user with at least one rated item is drawn;</li>
     *   <li>a positive item is drawn from the user's rated items;</li>
     *   <li>a negative item is drawn that is neither rated nor social for the user.</li>
     * </ul>
     * If the user has social items, one is drawn and a pairwise step is applied to each of
     * positive-over-social and social-over-negative. Otherwise a plain positive-over-negative
     * step is applied.
     *
     * @throws LibrecException if a cache lookup or prediction fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        int numTrainUsers = this.trainMatrix.numRows();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            int smax = this.numUsers * 100;
            for (int sample = 0; sample < smax; ++sample) {
                int userIdx;
                List<Integer> ratedItems;
                do {
                    userIdx = Randoms.uniform(numTrainUsers);
                    ratedItems = loadRatedItems(userIdx);
                } while (ratedItems.isEmpty());

                List<Integer> socialItems = this.userSocialItemsSetList.get(userIdx);
                // Rated and social items are disjoint, so if together they cover every item
                // there is no valid negative and the sampling loop below would never end.
                if (ratedItems.size() + socialItems.size() >= this.numItems) {
                    continue;
                }

                int posItemIdx = Randoms.random(ratedItems);
                double posPredictRating = predict(userIdx, posItemIdx);

                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (ratedItems.contains(negItemIdx) || socialItems.contains(negItemIdx));
                double negPredictRating = predict(userIdx, negItemIdx);

                if (socialItems.isEmpty()) {
                    updateWithoutSocialItem(userIdx, posItemIdx, negItemIdx, posPredictRating - negPredictRating);
                } else {
                    int socialItemIdx = Randoms.random(socialItems);
                    double socialPredictRating = predict(userIdx, socialItemIdx);
                    double socialScale = 1.0 + countTrustedRaters(userIdx, socialItemIdx);
                    updateWithSocialItem(userIdx, posItemIdx, socialItemIdx, negItemIdx,
                            (posPredictRating - socialPredictRating) / socialScale,
                            socialPredictRating - negPredictRating,
                            socialScale);
                }
            }
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /**
     * Counts the trusted users of {@code userIdx} who have a positive training rating on an item.
     *
     * @param userIdx user whose social-matrix row is scanned
     * @param itemIdx item checked in each trusted user's training row
     * @return the count as a double; trusted users without a training row are ignored
     */
    private double countTrustedRaters(int userIdx, int itemIdx) {
        int numTrainUsers = this.trainMatrix.numRows();
        SparseVector trustedUsersVector = this.socialMatrix.row(userIdx);
        double count = 0.0;
        for (VectorEntry trustedEntry : trustedUsersVector) {
            int trustedUserIdx = trustedEntry.index();
            if (trustedUserIdx < numTrainUsers && this.trainMatrix.get(trustedUserIdx, itemIdx) > 0.0) {
                count += 1.0;
            }
        }
        return count;
    }

    /**
     * Applies one gradient step to the two pairwise terms ln σ(posSocialDiff) + ln σ(socialNegDiff).
     * Accumulates their negative log-likelihood and L2 penalties into {@code loss}.
     *
     * @param posSocialDiffValue (x_ui - x_uk) already divided by {@code socialScale}
     * @param socialNegDiffValue x_uk - x_uj
     * @param socialScale        1 + number of trusted users who rated the social item
     */
    private void updateWithSocialItem(int userIdx, int posItemIdx, int socialItemIdx, int negItemIdx,
                                      double posSocialDiffValue, double socialNegDiffValue,
                                      double socialScale) {
        double lr = this.learnRate;
        double regB = this.regBias;
        double regU = this.regUser;
        double regI = this.regItem;

        this.loss += -Math.log(Maths.logistic(posSocialDiffValue)) - Math.log(Maths.logistic(socialNegDiffValue));
        double posSocialGradient = Maths.logistic(-posSocialDiffValue);
        double socialNegGradient = Maths.logistic(-socialNegDiffValue);

        double posItemBiasValue = this.itemBiases.get(posItemIdx);
        this.itemBiases.add(posItemIdx, lr * (posSocialGradient / socialScale - regB * posItemBiasValue));
        this.loss += regB * posItemBiasValue * posItemBiasValue;

        double socialItemBiasValue = this.itemBiases.get(socialItemIdx);
        this.itemBiases.add(socialItemIdx, lr * (-posSocialGradient / socialScale + socialNegGradient - regB * socialItemBiasValue));
        this.loss += regB * socialItemBiasValue * socialItemBiasValue;

        double negItemBiasValue = this.itemBiases.get(negItemIdx);
        this.itemBiases.add(negItemIdx, lr * (-socialNegGradient - regB * negItemBiasValue));
        this.loss += regB * negItemBiasValue * negItemBiasValue;

        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            // Read all factors first so every update uses pre-step values.
            double userFactorValue = this.userFactors.get(userIdx, factorIdx);
            double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
            double socialItemFactorValue = this.itemFactors.get(socialItemIdx, factorIdx);
            double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);

            double userGradient = posSocialGradient * (posItemFactorValue - socialItemFactorValue) / socialScale
                    + socialNegGradient * (socialItemFactorValue - negItemFactorValue);
            this.userFactors.add(userIdx, factorIdx, lr * (userGradient - regU * userFactorValue));
            this.itemFactors.add(posItemIdx, factorIdx, lr * (posSocialGradient * userFactorValue / socialScale - regI * posItemFactorValue));
            double socialItemGradient = posSocialGradient * (-userFactorValue / socialScale) + socialNegGradient * userFactorValue;
            this.itemFactors.add(socialItemIdx, factorIdx, lr * (socialItemGradient - regI * socialItemFactorValue));
            this.itemFactors.add(negItemIdx, factorIdx, lr * (socialNegGradient * -userFactorValue - regI * negItemFactorValue));

            this.loss += regU * userFactorValue * userFactorValue
                    + regI * posItemFactorValue * posItemFactorValue
                    + regI * negItemFactorValue * negItemFactorValue
                    + regI * socialItemFactorValue * socialItemFactorValue;
        }
    }

    /**
     * Applies one plain pairwise step on ln σ(x_ui - x_uj) for users without social items.
     * Accumulates its negative log-likelihood and L2 penalties into {@code loss}.
     *
     * @param posNegDiffValue x_ui - x_uj
     */
    private void updateWithoutSocialItem(int userIdx, int posItemIdx, int negItemIdx, double posNegDiffValue) {
        double lr = this.learnRate;
        double regB = this.regBias;
        double regU = this.regUser;
        double regI = this.regItem;

        // Same -ln σ(·) form as the social branch (previously the raw difference was added).
        this.loss += -Math.log(Maths.logistic(posNegDiffValue));
        double posNegGradient = Maths.logistic(-posNegDiffValue);

        double posItemBiasValue = this.itemBiases.get(posItemIdx);
        this.itemBiases.add(posItemIdx, lr * (posNegGradient - regB * posItemBiasValue));
        this.loss += regB * posItemBiasValue * posItemBiasValue;

        double negItemBiasValue = this.itemBiases.get(negItemIdx);
        this.itemBiases.add(negItemIdx, lr * (-posNegGradient - regB * negItemBiasValue));
        this.loss += regB * negItemBiasValue * negItemBiasValue;

        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            double userFactorValue = this.userFactors.get(userIdx, factorIdx);
            double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
            double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);

            this.userFactors.add(userIdx, factorIdx, lr * (posNegGradient * (posItemFactorValue - negItemFactorValue) - regU * userFactorValue));
            this.itemFactors.add(posItemIdx, factorIdx, lr * (posNegGradient * userFactorValue - regI * posItemFactorValue));
            this.itemFactors.add(negItemIdx, factorIdx, lr * (posNegGradient * -userFactorValue - regI * negItemFactorValue));

            this.loss += regU * userFactorValue * userFactorValue
                    + regI * posItemFactorValue * posItemFactorValue
                    + regI * negItemFactorValue * negItemFactorValue;
        }
    }

    /**
     * Scores a user-item pair.
     *
     * @param userIdx user index into {@code userFactors}
     * @param itemIdx item index into {@code itemFactors} and {@code itemBiases}
     * @return the item bias plus the dot product of the user and item factor rows
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.itemBiases.get(itemIdx) + DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
    }
}

