/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

/**
 * Ranking recommender based on a latent-topic ("aspect") model trained with
 * expectation-maximisation.
 *
 * <p>The model keeps three sets of parameters, all indexed by topic {@code z}:
 * <ul>
 *   <li>{@link #topicProbs} - {@code P(z)}, one value per topic;</li>
 *   <li>{@link #topicUserProbs} - {@code P(u|z)}, a {@code numTopics x numUsers} matrix;</li>
 *   <li>{@link #topicItemProbs} - {@code P(i|z)}, a {@code numTopics x numItems} matrix.</li>
 * </ul>
 * The E-step stores, for every observed (user, item) entry of the training
 * matrix, the posterior over topics; the M-step re-estimates the parameters
 * from those posteriors weighted by the entry values.
 */
public class AspectModelRecommender
extends ProbabilisticGraphicalRecommender {
    /** Number of latent topics, read from {@code rec.topic.number} (default 10). */
    protected int numTopics;
    /** {@code P(u|z)}: rows are topics, columns are users. */
    protected DenseMatrix topicUserProbs;
    /** M-step accumulator with the same layout as {@link #topicUserProbs}. */
    protected DenseMatrix topicUserProbsSum;
    /** {@code P(i|z)}: rows are topics, columns are items. */
    protected DenseMatrix topicItemProbs;
    /** M-step accumulator with the same layout as {@link #topicItemProbs}. */
    protected DenseMatrix topicItemProbsSum;
    /** {@code P(z)}: one entry per topic. */
    protected DenseVector topicProbs;
    /** M-step accumulator with the same layout as {@link #topicProbs}. */
    protected DenseVector topicProbsSum;
    /** Per training entry (user, item): posterior distribution over topics. */
    protected Table<Integer, Integer, double[]> entryTopicDistribution;

    /**
     * Reads the topic count from the configuration, marks the recommender as a
     * ranking recommender, randomly initialises all model parameters and
     * allocates one posterior array per training entry.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        numTopics = conf.getInt("rec.topic.number", 10);
        isRanking = true;

        // P(z): a single random distribution over topics.
        topicProbs = new DenseVector(numTopics);
        topicProbsSum = new DenseVector(numTopics);
        double[] initialTopicProbs = Randoms.randProbs(numTopics);
        for (int z = 0; z < numTopics; z++) {
            topicProbs.set(z, initialTopicProbs[z]);
        }

        // P(u|z): one random distribution over users per topic.
        topicUserProbs = new DenseMatrix(numTopics, numUsers);
        topicUserProbsSum = new DenseMatrix(numTopics, numUsers);
        fillRowsWithRandomProbs(topicUserProbs, numUsers);

        // P(i|z): one random distribution over items per topic.
        topicItemProbs = new DenseMatrix(numTopics, numItems);
        topicItemProbsSum = new DenseMatrix(numTopics, numItems);
        fillRowsWithRandomProbs(topicItemProbs, numItems);

        // One posterior buffer per observed (user, item) pair, reused by every E-step.
        entryTopicDistribution = HashBasedTable.create();
        for (MatrixEntry entry : trainMatrix) {
            entryTopicDistribution.put(entry.row(), entry.column(), new double[numTopics]);
        }
    }

    /**
     * E-step: for every training entry computes the posterior over topics,
     * proportional to {@code P(u|z) * P(i|z) * P(z)}, and stores it in
     * {@link #entryTopicDistribution}.
     *
     * <p>If all joint probabilities of an entry are zero the posterior falls
     * back to the uniform distribution instead of producing {@code NaN}
     * (which would otherwise contaminate every accumulator in the M-step).
     */
    @Override
    protected void eStep() {
        for (MatrixEntry entry : trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double[] posterior = entryTopicDistribution.get(userIdx, itemIdx);

            double total = 0.0;
            for (int z = 0; z < numTopics; z++) {
                double joint = topicUserProbs.get(z, userIdx)
                        * topicItemProbs.get(z, itemIdx)
                        * topicProbs.get(z);
                posterior[z] = joint;
                total += joint;
            }

            if (total > 0.0) {
                for (int z = 0; z < numTopics; z++) {
                    posterior[z] /= total;
                }
            } else {
                // Degenerate entry: avoid 0/0 and keep a valid distribution.
                double uniform = 1.0 / numTopics;
                for (int z = 0; z < numTopics; z++) {
                    posterior[z] = uniform;
                }
            }
        }
    }

    /**
     * M-step: accumulates the posterior-weighted entry values per topic,
     * per (topic, user) and per (topic, item), then normalises them into
     * {@code P(z)}, {@code P(u|z)} and {@code P(i|z)}.
     *
     * <p>Each row of {@code P(u|z)} / {@code P(i|z)} is divided by the total of
     * that same row of the freshly accumulated sums, so every non-empty row is
     * a proper distribution over users / items.
     */
    @Override
    protected void mStep() {
        topicProbsSum.setAll(0.0);
        topicUserProbsSum.setAll(0.0);
        topicItemProbsSum.setAll(0.0);

        // Single pass over the training entries; the posterior array is looked
        // up once per entry instead of once per (topic, entry) pair.
        for (MatrixEntry entry : trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double weight = entry.get();
            double[] posterior = entryTopicDistribution.get(userIdx, itemIdx);
            for (int z = 0; z < numTopics; z++) {
                double val = posterior[z] * weight;
                topicProbsSum.add(z, val);
                topicUserProbsSum.add(z, userIdx, val);
                topicItemProbsSum.add(z, itemIdx, val);
            }
        }

        // Same "divide by 1 when the total is not positive" convention as the
        // row normalisation below, so an empty model yields zeros, not NaN.
        double topicTotal = topicProbsSum.sum();
        topicProbs = topicProbsSum.scale(1.0 / (topicTotal > 0.0 ? topicTotal : 1.0));

        normalizeRows(topicUserProbsSum, topicUserProbs, numUsers);
        normalizeRows(topicItemProbsSum, topicItemProbs, numItems);
    }

    /**
     * Scores a (user, item) pair as {@code sum_z P(u|z) * P(i|z) * P(z)}.
     *
     * @param userIdx index of the user
     * @param itemIdx index of the item
     * @return the model's joint probability for the pair, used as ranking score
     * @throws LibrecException declared by the overridden method; not thrown here
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double score = 0.0;
        for (int z = 0; z < numTopics; z++) {
            score += topicUserProbs.get(z, userIdx)
                    * topicItemProbs.get(z, itemIdx)
                    * topicProbs.get(z);
        }
        return score;
    }

    /**
     * Fills every topic row of {@code target} with an independently drawn
     * random probability vector of length {@code numColumns}.
     */
    private void fillRowsWithRandomProbs(DenseMatrix target, int numColumns) {
        for (int z = 0; z < numTopics; z++) {
            double[] rowProbs = Randoms.randProbs(numColumns);
            for (int col = 0; col < numColumns; col++) {
                target.set(z, col, rowProbs[col]);
            }
        }
    }

    /**
     * Writes into {@code probs} each topic row of {@code sums} divided by that
     * row's total. A row whose total is not positive is divided by 1 instead,
     * leaving its (zero) values unchanged.
     */
    private void normalizeRows(DenseMatrix sums, DenseMatrix probs, int numColumns) {
        for (int z = 0; z < numTopics; z++) {
            double rowTotal = 0.0;
            for (int col = 0; col < numColumns; col++) {
                rowTotal += sums.get(z, col);
            }
            double divisor = rowTotal > 0.0 ? rowTotal : 1.0;
            for (int col = 0; col < numColumns; col++) {
                probs.set(z, col, sums.get(z, col) / divisor);
            }
        }
    }
}

