/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gaussian;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

/**
 * Latent-topic ("aspect") rating recommender fitted with expectation-maximization.
 *
 * <p>Each training rating {@code r} of user {@code u} on item {@code i} is scored under topic
 * {@code z} as {@code P(z) * P(u|z) * P(i|z) * Gaussian.pdf(r, mean_z, variance_z)}.
 * {@link #eStep()} turns these scores into per-rating topic responsibilities stored in {@link #Q}.
 * {@link #mStep()} re-estimates all parameters from those responsibilities. {@link #predict(int, int)}
 * returns the topic means averaged with weights {@code P(z) * P(u|z) * P(i|z)}.
 */
public class AspectModelRecommender
extends ProbabilisticGraphicalRecommender {
    /** Per-topic user weights P(u|z); shape numTopics x numUsers. */
    protected DenseMatrix topicUserProbs;
    /** M-step accumulator of responsibilities per (topic, user). */
    protected DenseMatrix topicUserProbsSum;
    /** Per-topic item weights P(i|z); shape numTopics x numItems. */
    protected DenseMatrix topicItemProbs;
    /** M-step accumulator of responsibilities per (topic, item). */
    protected DenseMatrix topicItemProbsSum;
    /** Topic mixing weights P(z); length numTopics. */
    protected DenseVector topicProbs;
    /** M-step accumulator of total responsibility per topic (plus {@link #smallValue}). */
    protected DenseVector topicProbsSum;
    /** Per-topic rating mean used by the Gaussian rating likelihood and by prediction. */
    protected DenseVector topicProbsMean;
    /** M-step accumulator of responsibility-weighted ratings per topic. */
    protected DenseVector topicProbsMeanSum;
    /** Per-topic responsibility-weighted rating variance (initialised to 2.0). */
    protected DenseVector topicProbsVariance;
    /** M-step accumulator of responsibility-weighted squared deviations per topic. */
    protected DenseVector topicProbsVarianceSum;
    /** Number of latent topics, read from {@code rec.factory.number} (default 10). */
    protected int numTopics;
    /** Smoothing constant added to M-step denominators/numerators to avoid division by zero. */
    protected static double smallValue = 1.0E-7;
    /** For each training (user, item) pair: topic index -> posterior responsibility from the E-step. */
    protected Table<Integer, Integer, Map<Integer, Double>> Q;

    /**
     * Reads the topic count and randomly initialises the topic prior and the per-topic user and
     * item weights. It also allocates one responsibility map per training rating and starts every
     * topic's rating Gaussian at the global training mean with variance 2.0.
     *
     * @throws LibrecException if the base-class setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        numTopics = conf.getInt("rec.factory.number", 10);

        // The random draws below happen in a fixed order (prior, users, items) so that
        // seeded runs stay reproducible.
        topicProbs = new DenseVector(numTopics);
        topicProbsSum = new DenseVector(numTopics);
        double[] prior = Randoms.randProbs(numTopics);
        for (int z = 0; z < numTopics; ++z) {
            topicProbs.set(z, prior[z]);
        }

        topicUserProbs = randomTopicRows(numTopics, numUsers);
        topicUserProbsSum = new DenseMatrix(numTopics, numUsers);
        topicItemProbs = randomTopicRows(numTopics, numItems);
        topicItemProbsSum = new DenseMatrix(numTopics, numItems);

        Q = HashBasedTable.create();
        for (MatrixEntry me : trainMatrix) {
            Q.put(me.row(), me.column(), new HashMap<Integer, Double>());
        }

        double globalMean = trainMatrix.mean();
        topicProbsMean = new DenseVector(numTopics);
        topicProbsVariance = new DenseVector(numTopics);
        topicProbsMeanSum = new DenseVector(numTopics);
        topicProbsVarianceSum = new DenseVector(numTopics);
        for (int z = 0; z < numTopics; ++z) {
            topicProbsMean.set(z, globalMean);
            topicProbsVariance.set(z, 2.0);
        }
    }

    /** Builds a rows x cols matrix whose rows are filled by successive {@code Randoms.randProbs(cols)} draws. */
    private static DenseMatrix randomTopicRows(int rows, int cols) {
        DenseMatrix matrix = new DenseMatrix(rows, cols);
        for (int row = 0; row < rows; ++row) {
            double[] probs = Randoms.randProbs(cols);
            for (int col = 0; col < cols; ++col) {
                matrix.set(row, col, probs[col]);
            }
        }
        return matrix;
    }

    /** Delegates to the base-class training loop. */
    @Override
    protected void trainModel() throws LibrecException {
        super.trainModel();
    }

    /**
     * E-step: for every training rating, stores in {@link #Q} the normalised posterior weight of
     * each topic. If all topic scores are zero for a rating, every responsibility is set to 0.
     */
    @Override
    protected void eStep() {
        // Fully overwritten for each rating, so one buffer is enough.
        double[] numerator = new double[numTopics];
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            double r = me.get();
            double denominator = 0.0;
            for (int z = 0; z < numTopics; ++z) {
                // NOTE(review): the variance is passed as Gaussian.pdf's third argument; if that
                // parameter is a standard deviation, this should be Math.sqrt(variance) — confirm.
                double val = topicProbs.get(z) * topicUserProbs.get(z, u) * topicItemProbs.get(z, i)
                        * Gaussian.pdf(r, topicProbsMean.get(z), topicProbsVariance.get(z));
                numerator[z] = val;
                denominator += val;
            }
            Map<Integer, Double> qTopicProbs = Q.get(u, i);
            for (int z = 0; z < numTopics; ++z) {
                qTopicProbs.put(z, denominator > 0.0 ? numerator[z] / denominator : 0.0);
            }
        }
    }

    /**
     * M-step: re-estimates P(z), P(u|z), P(i|z) and the per-topic rating mean and variance from
     * the responsibilities in {@link #Q}. {@link #smallValue} smooths the denominators.
     */
    @Override
    protected void mStep() {
        topicProbsSum.setAll(0.0);
        topicUserProbsSum.setAll(0.0);
        topicItemProbsSum.setAll(0.0);
        topicProbsMeanSum.setAll(0.0);
        topicProbsVarianceSum.setAll(0.0);

        // Pass 1: accumulate the sufficient statistics for all topics in one sweep. Each
        // topic's sums are still added in training-entry order, so results match a per-topic loop.
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            double r = me.get();
            Map<Integer, Double> qTopicProbs = Q.get(u, i);
            for (int z = 0; z < numTopics; ++z) {
                double val = qTopicProbs.get(z);
                topicProbsSum.add(z, val);
                topicUserProbsSum.add(z, u, val);
                topicItemProbsSum.add(z, i, val);
                topicProbsMeanSum.add(z, r * val);
            }
        }

        double[] means = new double[numTopics];
        for (int z = 0; z < numTopics; ++z) {
            topicProbsSum.add(z, smallValue);
            double topicMass = topicProbsSum.get(z);
            topicProbs.set(z, topicMass / (double) numRates);
            for (int u = 0; u < numUsers; ++u) {
                topicUserProbs.set(z, u, topicUserProbsSum.get(z, u) / topicMass);
            }
            for (int i = 0; i < numItems; ++i) {
                topicItemProbs.set(z, i, topicItemProbsSum.get(z, i) / topicMass);
            }
            means[z] = topicProbsMeanSum.get(z) / topicMass;
        }

        // Pass 2: responsibility-weighted squared deviation around the freshly estimated means.
        for (MatrixEntry me : trainMatrix) {
            double r = me.get();
            Map<Integer, Double> qTopicProbs = Q.get(me.row(), me.column());
            for (int z = 0; z < numTopics; ++z) {
                double val = qTopicProbs.get(z);
                double diff = r - means[z];
                topicProbsVarianceSum.add(z, diff * diff * val);
            }
        }

        for (int z = 0; z < numTopics; ++z) {
            topicProbsMean.set(z, means[z]);
            topicProbsVariance.set(z, (topicProbsVarianceSum.get(z) + smallValue) / topicProbsSum.get(z));
        }
    }

    /** No-op: {@link #mStep()} writes the parameter estimates directly into the model fields. */
    @Override
    protected void readoutParams() {
    }

    /** No-op: {@link #mStep()} writes the parameter estimates directly into the model fields. */
    @Override
    protected void estimateParams() {
    }

    /**
     * Predicts a rating as the topic means averaged with weights P(z) * P(u|z) * P(i|z).
     *
     * <p>If all of those weights are zero (for example, a user or item without training ratings,
     * whose P(u|z) or P(i|z) column stays 0), the prediction falls back to the P(z)-weighted
     * mixture of topic means. Before this fallback, the method returned NaN (0/0) in that case.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the predicted rating
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double weightedSum = 0.0;
        double totalWeight = 0.0;
        for (int z = 0; z < numTopics; ++z) {
            double weight = topicProbs.get(z) * topicUserProbs.get(z, userIdx) * topicItemProbs.get(z, itemIdx);
            totalWeight += weight;
            weightedSum += weight * topicProbsMean.get(z);
        }
        if (totalWeight > 0.0) {
            return weightedSum / totalWeight;
        }
        return predictFromTopicPrior();
    }

    /** Returns the topic means weighted by P(z) alone; returns 0.0 if all P(z) are zero. */
    private double predictFromTopicPrior() {
        double weightedSum = 0.0;
        double totalWeight = 0.0;
        for (int z = 0; z < numTopics; ++z) {
            double weight = topicProbs.get(z);
            totalWeight += weight;
            weightedSum += weight * topicProbsMean.get(z);
        }
        return totalWeight > 0.0 ? weightedSum / totalWeight : 0.0;
    }
}

