/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

public class LDCCRecommender
extends ProbabilisticGraphicalRecommender {
    private Table<Integer, Integer, Integer> userTopics;
    private Table<Integer, Integer, Integer> itemTopics;
    private DenseMatrix numEachUserTopics;
    private DenseMatrix numEachItemTopics;
    private DenseVector numEachUserRatings;
    private DenseVector numEachItemRatings;
    private DenseMatrix numUserItemTopics;
    private int[][][] numUserItemRatingTopics;
    private int numUserTopics;
    private int numItemTopics;
    private double userAlpha;
    private double itemAlpha;
    private double ratingBeta;
    private DenseMatrix userTopicProbs;
    private DenseMatrix itemTopicProbs;
    private DenseMatrix userTopicProbsSum;
    private DenseMatrix itemTopicProbsSum;
    private double[][][] userItemRatingTopicProbs;
    private double[][][] userItemRatingTopicProbsSum;
    private int numRatingLevels;
    private List<Double> ratingScale;
    private int numStats;
    private double loss;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numStats = 0;
        this.numUserTopics = this.conf.getInt("rec.pgm.number.users", 10);
        this.numItemTopics = this.conf.getInt("rec.pgm.number.items", 10);
        this.burnIn = this.conf.getInt("rec.pgm.burn-in", 100);
        Set<Double> ratingScaleSet = this.trainMatrix.getValueSet();
        this.ratingScale = new ArrayList<Double>(ratingScaleSet);
        this.numRatingLevels = this.ratingScale.size();
        this.userAlpha = this.conf.getDouble("rec.pgm.user.alpha", 1.0 / (double)this.numUserTopics);
        this.itemAlpha = this.conf.getDouble("rec.pgm.item.alpha", 1.0 / (double)this.numItemTopics);
        this.ratingBeta = this.conf.getDouble("rec.pgm.rating.beta", 1.0 / (double)this.numRatingLevels);
        this.numEachUserTopics = new DenseMatrix(this.numUsers, this.numUserTopics);
        this.numEachItemTopics = new DenseMatrix(this.numItems, this.numItemTopics);
        this.numEachUserRatings = new DenseVector(this.numUsers);
        this.numEachItemRatings = new DenseVector(this.numItems);
        this.numUserItemRatingTopics = new int[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
        this.numUserItemTopics = new DenseMatrix(this.numUserTopics, this.numItemTopics);
        this.userTopics = HashBasedTable.create();
        this.itemTopics = HashBasedTable.create();
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int v = me.column();
            double rating = me.get();
            int r = this.ratingScale.indexOf(rating);
            int i = (int)((double)this.numUserTopics * Randoms.uniform());
            int j = (int)((double)this.numItemTopics * Randoms.uniform());
            this.numEachUserTopics.add(u, i, 1.0);
            this.numEachUserRatings.add(u, 1.0);
            this.numEachItemTopics.add(v, j, 1.0);
            this.numEachItemRatings.add(v, 1.0);
            int[] nArray = this.numUserItemRatingTopics[i][j];
            int n = r;
            nArray[n] = nArray[n] + 1;
            this.numUserItemTopics.add(i, j, 1.0);
            this.userTopics.put(u, v, i);
            this.itemTopics.put(u, v, j);
        }
        this.userTopicProbsSum = new DenseMatrix(this.numUsers, this.numUserTopics);
        this.itemTopicProbsSum = new DenseMatrix(this.numItems, this.numItemTopics);
        this.userItemRatingTopicProbs = new double[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
        this.userItemRatingTopicProbsSum = new double[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
    }

    @Override
    protected void eStep() {
        for (MatrixEntry me : this.trainMatrix) {
            int n;
            int m;
            int u = me.row();
            int v = me.column();
            double rating = me.get();
            int r = this.ratingScale.indexOf(rating);
            int i = this.userTopics.get(u, v);
            int j = this.itemTopics.get(u, v);
            this.numEachUserTopics.add(u, i, -1.0);
            this.numEachUserRatings.add(u, -1.0);
            this.numEachItemTopics.add(v, j, -1.0);
            this.numEachItemRatings.add(v, -1.0);
            int[] nArray = this.numUserItemRatingTopics[i][j];
            int n2 = r;
            nArray[n2] = nArray[n2] - 1;
            this.numUserItemTopics.add(i, j, -1.0);
            DenseMatrix probs = new DenseMatrix(this.numUserTopics, this.numItemTopics);
            double sum = 0.0;
            for (int m2 = 0; m2 < this.numUserTopics; ++m2) {
                for (int n3 = 0; n3 < this.numItemTopics; ++n3) {
                    double v1 = (this.numEachUserTopics.get(u, m2) + this.userAlpha) / (this.numEachUserRatings.get(u) + (double)this.numUserTopics * this.userAlpha);
                    double v2 = (this.numEachUserTopics.get(i, n3) + this.itemAlpha) / (this.numEachItemRatings.get(v) + (double)this.numItemTopics * this.itemAlpha);
                    double v3 = ((double)this.numUserItemRatingTopics[m2][n3][r] + this.ratingBeta) / (this.numUserItemTopics.get(m2, n3) + (double)this.numRatingLevels * this.ratingBeta);
                    double prob = v1 * v2 * v3;
                    probs.set(m2, n3, prob);
                    sum += prob;
                }
            }
            probs = probs.scale(1.0 / sum);
            double[] Pu = new double[this.numUserTopics];
            for (m = 0; m < this.numUserTopics; ++m) {
                Pu[m] = probs.sumOfRow(m);
            }
            for (m = 1; m < this.numUserTopics; ++m) {
                int n4 = m;
                Pu[n4] = Pu[n4] + Pu[m - 1];
            }
            double rand = Randoms.uniform();
            for (i = 0; i < this.numUserTopics && !(rand < Pu[i]); ++i) {
            }
            double[] Pv = new double[this.numItemTopics];
            for (n = 0; n < this.numItemTopics; ++n) {
                Pv[n] = probs.sumOfColumn(n);
            }
            for (n = 1; n < this.numItemTopics; ++n) {
                int n5 = n;
                Pv[n5] = Pv[n5] + Pv[n - 1];
            }
            rand = Randoms.uniform();
            for (j = 0; j < this.numItemTopics && !(rand < Pv[j]); ++j) {
            }
            this.numEachUserTopics.add(u, i, 1.0);
            this.numEachUserRatings.add(u, 1.0);
            this.numEachItemTopics.add(v, j, 1.0);
            this.numEachItemRatings.add(v, 1.0);
            int[] nArray2 = this.numUserItemRatingTopics[i][j];
            int n6 = r;
            nArray2[n6] = nArray2[n6] + 1;
            this.numUserItemTopics.add(i, j, 1.0);
            this.userTopics.put(u, v, i);
            this.itemTopics.put(u, v, j);
        }
    }

    @Override
    protected void mStep() {
    }

    @Override
    protected void readoutParams() {
        int j;
        for (int u = 0; u < this.numUsers; ++u) {
            for (int i = 0; i < this.numUserTopics; ++i) {
                this.userTopicProbsSum.add(u, i, (this.numEachUserTopics.get(u, i) + this.userAlpha) / (this.numEachUserRatings.get(u) + (double)this.numUserTopics * this.userAlpha));
            }
        }
        for (int v = 0; v < this.numItems; ++v) {
            for (j = 0; j < this.numItemTopics; ++j) {
                this.itemTopicProbsSum.add(v, j, (this.numEachItemTopics.get(v, j) + this.itemAlpha) / (this.numEachItemRatings.get(v) + (double)this.numItemTopics * this.itemAlpha));
            }
        }
        for (int i = 0; i < this.numUserTopics; ++i) {
            for (j = 0; j < this.numItemTopics; ++j) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    double[] dArray = this.userItemRatingTopicProbsSum[i][j];
                    int n = r;
                    dArray[n] = dArray[n] + ((double)this.numUserItemRatingTopics[i][j][r] + this.ratingBeta) / (this.numUserItemTopics.get(i, j) + (double)this.numRatingLevels * this.ratingBeta);
                }
            }
        }
        ++this.numStats;
    }

    @Override
    protected void estimateParams() {
        this.userTopicProbs = this.userTopicProbsSum.scale(1.0 / (double)this.numStats);
        this.itemTopicProbs = this.itemTopicProbsSum.scale(1.0 / (double)this.numStats);
        for (int i = 0; i < this.numUserTopics; ++i) {
            for (int j = 0; j < this.numItemTopics; ++j) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    this.userItemRatingTopicProbs[i][j][r] = this.userItemRatingTopicProbsSum[i][j][r] / (double)this.numStats;
                }
            }
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double pred = 0.0;
        for (int l = 0; l < this.numRatingLevels; ++l) {
            double rate = this.ratingScale.get(l);
            double prob = 0.0;
            for (int i = 0; i < this.numUserTopics; ++i) {
                for (int j = 0; j < this.numItemTopics; ++j) {
                    prob += this.userItemRatingTopicProbs[i][j][l] * this.userTopicProbs.get(userIdx, i) * this.itemTopicProbs.get(itemIdx, j);
                }
            }
            pred += rate * prob;
        }
        return pred;
    }

    @Override
    protected boolean isConverged(int iter) {
        this.estimateParams();
        int N = 0;
        double sum = 0.0;
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int v = me.column();
            double rating = me.get();
            sum += this.perplexity(u, v, rating);
            ++N;
        }
        double perp = Math.exp(sum / (double)N);
        double delta = perp - this.loss;
        if (this.numStats > 1 && delta > 0.0) {
            return true;
        }
        this.loss = perp;
        return false;
    }

    protected double perplexity(int user, int item, double rating) {
        int r = (int)(rating / this.minRate - 1.0);
        double prob = 0.0;
        for (int i = 0; i < this.numUserTopics; ++i) {
            for (int j = 0; j < this.numItemTopics; ++j) {
                prob += this.userItemRatingTopicProbs[i][j][r] * this.userTopicProbs.get(user, i) * this.itemTopicProbs.get(item, j);
            }
        }
        return -Math.log(prob);
    }
}

