/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gaussian;
import net.librec.math.algorithm.Randoms;
import net.librec.math.algorithm.Stats;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

/**
 * Gaussian pLSA (GPLSA) rating predictor.
 *
 * <p>Every user {@code u} has a distribution {@code P(z|u)} over {@link #numTopics} latent
 * topics, and every (item, topic) pair has a Gaussian over the user-normalized rating.
 * {@link #eStep()} computes the (tempered) topic responsibilities of each training rating
 * into {@link #Q}; {@link #mStep()} re-estimates {@code P(z|u)} and the per-(item, topic)
 * Gaussians from them. The EM iteration itself is presumably driven by the superclass's
 * {@code trainModel()}.
 *
 * <p>Side effect: {@link #setup()} overwrites every entry of {@code trainMatrix} with its
 * user-normalized value {@code (r - userMu[u]) / userSigma[u]}.
 */
public class GPLSARecommender extends ProbabilisticGraphicalRecommender {
    /** Configuration key of the tempering exponent {@link #b}. */
    private static final String TEMPERING_KEY = "rec.recommender.gplsa.beta";

    /** Default tempering exponent: 1 means standard (untempered) EM. */
    private static final float DEFAULT_TEMPERING = 1.0f;

    /** Number of latent topics ({@code rec.topic.number}, default 10). */
    protected int numTopics;
    /** Responsibilities: (user, item) -> (topic -> responsibility of that topic for the rating). */
    protected Table<Integer, Integer, Map<Integer, Double>> Q;
    /** {@code P(z|u)}: one row per user, one column per topic. */
    protected DenseMatrix userTopicProbs;
    /** Gaussian mean of the normalized rating, indexed (item, topic). */
    protected DenseMatrix topicItemMu;
    /** Gaussian standard deviation of the normalized rating, indexed (item, topic). */
    protected DenseMatrix topicItemSigma;
    /** Smoothed per-user rating mean used to normalize that user's ratings. */
    protected DenseVector userMu;
    /** Smoothed per-user rating standard deviation used to normalize that user's ratings. */
    protected DenseVector userSigma;
    /** Pseudo-count pulling per-user statistics toward the global ones ({@code rec.recommender.smoothWeight}). */
    protected float smoothWeight;
    /** Tempering exponent applied to each E-step term; 1 gives standard EM. */
    protected float b;
    /** Scale of the random jitter on the initial item Gaussians; also the lower bound for sigma. */
    protected static double smallValue = 0.01;

    /**
     * Reads the configuration, normalizes the training ratings per user and initializes all
     * model parameters.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numTopics = this.conf.getInt("rec.topic.number", 10);
        this.smoothWeight = this.conf.getInt("rec.recommender.smoothWeight").intValue();
        // b used to be left at 0; Math.pow(x, 0) == 1 then made every E-step posterior uniform.
        this.b = this.conf.getFloat(TEMPERING_KEY, DEFAULT_TEMPERING);
        initUserTopicProbs();
        initUserStatistics();
        normalizeRatings();
        initTopicItemGaussians();
    }

    /** Draws a random topic distribution {@code P(z|u)} for every user. */
    private void initUserTopicProbs() {
        this.userTopicProbs = new DenseMatrix(this.numUsers, this.numTopics);
        for (int u = 0; u < this.numUsers; ++u) {
            double[] probs = Randoms.randProbs(this.numTopics);
            for (int z = 0; z < this.numTopics; ++z) {
                this.userTopicProbs.set(u, z, probs[z]);
            }
        }
    }

    /**
     * Computes per-user rating mean and standard deviation, each shrunk toward the global
     * statistics with {@link #smoothWeight} pseudo-ratings. Users without ratings are skipped.
     */
    private void initUserStatistics() {
        double globalMean = this.trainMatrix.mean();
        double globalSd = Stats.sd(this.trainMatrix.getData(), globalMean);
        this.userMu = new DenseVector(this.numUsers);
        this.userSigma = new DenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SparseVector userRow = this.trainMatrix.row(u);
            int numRatings = userRow.size();
            if (numRatings < 1) {
                continue;
            }
            double effectiveCount = (double) ((float) numRatings + this.smoothWeight);
            double mu = (userRow.sum() + (double) this.smoothWeight * globalMean) / effectiveCount;
            double squaredDeviation = 0.0;
            for (VectorEntry ve : userRow) {
                squaredDeviation += Math.pow(ve.get() - mu, 2.0);
            }
            // The pseudo-ratings contribute the global variance.
            squaredDeviation += (double) this.smoothWeight * Math.pow(globalSd, 2.0);
            this.userMu.set(u, mu);
            this.userSigma.set(u, Math.sqrt(squaredDeviation / effectiveCount));
        }
    }

    /**
     * Replaces each training rating in place with its user-normalized value and allocates an
     * empty responsibility map for it in {@link #Q}.
     */
    private void normalizeRatings() {
        this.Q = HashBasedTable.create();
        for (MatrixEntry entry : this.trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double sigma = this.userSigma.get(userIdx);
            // A zero deviation (e.g. smoothWeight 0 and identical ratings) would yield NaN/Infinity;
            // such ratings sit at the user's mean, so their normalized value is 0.
            double normalized = sigma > 0.0 ? (entry.get() - this.userMu.get(userIdx)) / sigma : 0.0;
            this.trainMatrix.set(userIdx, itemIdx, normalized);
            this.Q.put(userIdx, itemIdx, new HashMap<Integer, Double>());
        }
    }

    /**
     * Initializes each (item, topic) Gaussian from the item's normalized-rating mean and
     * population standard deviation, plus a small uniform jitter per topic so the topics do
     * not start identical. Items without ratings are skipped.
     */
    private void initTopicItemGaussians() {
        this.topicItemMu = new DenseMatrix(this.numItems, this.numTopics);
        this.topicItemSigma = new DenseMatrix(this.numItems, this.numTopics);
        for (int i = 0; i < this.numItems; ++i) {
            SparseVector itemColumn = this.trainMatrix.column(i);
            int numRatings = itemColumn.size();
            if (numRatings < 1) {
                continue;
            }
            double mu = itemColumn.mean();
            double squaredDeviation = 0.0;
            for (VectorEntry ve : itemColumn) {
                squaredDeviation += Math.pow(ve.get() - mu, 2.0);
            }
            double sd = Math.sqrt(squaredDeviation / (double) numRatings);
            for (int z = 0; z < this.numTopics; ++z) {
                this.topicItemMu.set(i, z, mu + smallValue * Randoms.uniform());
                this.topicItemSigma.set(i, z, sd + smallValue * Randoms.uniform());
            }
        }
    }

    /**
     * Delegates to the superclass training routine.
     *
     * @throws LibrecException propagated from {@code super.trainModel()}
     */
    @Override
    protected void trainModel() throws LibrecException {
        super.trainModel();
    }

    /**
     * E-step. For every (normalized) training rating {@code r}, computes
     * {@code Q[u][i][z] = (P(z|u) * N(r; mu_iz, sigma_iz))^b / sum_z' (P(z'|u) * N(r; mu_iz', sigma_iz'))^b}.
     * If the normaliser is not positive (all terms zero, or NaN), every responsibility of that
     * rating is set to 0.
     */
    @Override
    protected void eStep() {
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rating = me.get();
            double denominator = 0.0;
            double[] numerator = new double[this.numTopics];
            for (int z = 0; z < this.numTopics; ++z) {
                double pdf = Gaussian.pdf(rating, this.topicItemMu.get(i, z), this.topicItemSigma.get(i, z));
                double val = Math.pow(this.userTopicProbs.get(u, z) * pdf, this.b);
                numerator[z] = val;
                denominator += val;
            }
            Map<Integer, Double> factorProbs = this.Q.get(u, i);
            for (int z = 0; z < this.numTopics; ++z) {
                double prob = denominator > 0.0 ? numerator[z] / denominator : 0.0;
                factorProbs.put(z, prob);
            }
        }
    }

    /** M-step: re-estimates {@code P(z|u)} and then the per-(item, topic) Gaussians from {@link #Q}. */
    @Override
    protected void mStep() {
        updateUserTopicProbs();
        updateTopicItemGaussians();
    }

    /**
     * Sets {@code P(z|u)} proportional to the sum of the user's responsibilities for topic
     * {@code z} over all rated items. Users without ratings, or whose responsibilities sum to
     * zero (or NaN), keep their previous estimate.
     */
    private void updateUserTopicProbs() {
        for (int u = 0; u < this.numUsers; ++u) {
            List<Integer> items = this.trainMatrix.getColumns(u);
            if (items.isEmpty()) {
                continue;
            }
            double[] numerator = new double[this.numTopics];
            double denominator = 0.0;
            for (int z = 0; z < this.numTopics; ++z) {
                for (int i : items) {
                    // Accumulate over items; plain assignment kept only the last item's value.
                    numerator[z] += this.Q.get(u, i).get(z);
                }
                denominator += numerator[z];
            }
            // Written as !(x > 0) so that a NaN total is skipped as well.
            if (!(denominator > 0.0)) {
                continue;
            }
            for (int z = 0; z < this.numTopics; ++z) {
                this.userTopicProbs.set(u, z, numerator[z] / denominator);
            }
        }
    }

    /**
     * Re-estimates each (item, topic) Gaussian as the responsibility-weighted mean and standard
     * deviation of the item's normalized ratings. When the total weight is zero, mu falls back
     * to 0. Sigma is floored at {@link #smallValue}: a zero deviation (e.g. an item with a single
     * rating) is a degenerate Gaussian whose density is not finite. That would wipe out the next
     * E-step's responsibilities for the item.
     */
    private void updateTopicItemGaussians() {
        for (int i = 0; i < this.numItems; ++i) {
            List<Integer> users = this.trainMatrix.getRows(i);
            if (users.isEmpty()) {
                continue;
            }
            for (int z = 0; z < this.numTopics; ++z) {
                double weightedSum = 0.0;
                double totalWeight = 0.0;
                for (int u : users) {
                    double prob = this.Q.get(u, i).get(z);
                    weightedSum += this.trainMatrix.get(u, i) * prob;
                    totalWeight += prob;
                }
                double mu = totalWeight > 0.0 ? weightedSum / totalWeight : 0.0;
                double weightedSquares = 0.0;
                for (int u : users) {
                    double prob = this.Q.get(u, i).get(z);
                    weightedSquares += Math.pow(this.trainMatrix.get(u, i) - mu, 2.0) * prob;
                }
                double sigma = totalWeight > 0.0 ? Math.sqrt(weightedSquares / totalWeight) : 0.0;
                this.topicItemMu.set(i, z, mu);
                this.topicItemSigma.set(i, z, Math.max(sigma, smallValue));
            }
        }
    }

    /**
     * Predicts a rating by de-normalizing the expected normalized rating
     * {@code sum_z P(z|u) * mu_iz}, i.e. {@code userMu[u] + userSigma[u] * sum}.
     * Users without training ratings keep their initial {@code userMu}/{@code userSigma}
     * entries (presumably 0 — NOTE(review): confirm DenseVector zero-initialisation).
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index
     * @return the predicted rating on the original rating scale
     * @throws LibrecException declared by the overridden method; not thrown here
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double expectedNormalized = 0.0;
        for (int z = 0; z < this.numTopics; ++z) {
            expectedNormalized += this.userTopicProbs.get(userIdx, z) * this.topicItemMu.get(itemIdx, z);
        }
        return this.userMu.get(userIdx) + this.userSigma.get(userIdx) * expectedNormalized;
    }
}

