/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

/**
 * Latent Dirichlet Allocation (LDA) recommender trained with collapsed Gibbs sampling.
 *
 * <p>Each user plays the role of a document and each item the role of a word. Every training
 * entry contributes {@code (int) value} tokens (fractional values are truncated), and each token
 * carries one topic assignment. The E-step resamples the assignments from their full
 * conditional. The M-step re-estimates the per-topic ({@code alpha}) and per-item
 * ({@code beta}) Dirichlet hyper-parameters with a fixed-point update. Each readout adds the
 * current user-topic and topic-item distributions to running sums, and their average is used
 * by {@link #predict(int, int)}.
 *
 * <p>Token-level state ({@link #topicAssignments}) is indexed by position in the iteration
 * order of {@code trainMatrix}. This relies on that order being identical on every pass.
 */
@ModelData(value={"isRanking", "lda", "userTopicProbs", "topicItemProbs", "trainMatrix"})
public class LDARecommender extends ProbabilisticGraphicalRecommender {
    /** Initial value of every component of {@link #alpha}. */
    protected float initAlpha;
    /** Initial value of every component of {@link #beta}. */
    protected float initBeta;
    /** numTopics x numItems: number of tokens of each item currently assigned to each topic. */
    protected DenseMatrix topicItemNumbers;
    /** numUsers x numTopics: number of each user's tokens currently assigned to each topic. */
    protected DenseMatrix userTopicNumbers;
    /** Topic of every token, in trainMatrix iteration order. */
    protected List<Integer> topicAssignments;
    /** Total number of tokens per user. */
    protected DenseVector userTokenNumbers;
    /** Total number of tokens currently assigned to each topic. */
    protected DenseVector topicTokenNumbers;
    /** Number of latent topics ({@code rec.topic.number}). */
    protected int numTopics;
    /** Per-topic Dirichlet prior on user-topic distributions. */
    protected DenseVector alpha;
    /** Per-item Dirichlet prior on topic-item distributions. */
    protected DenseVector beta;
    /** Running sum of user-topic distribution readouts. */
    protected DenseMatrix userTopicProbsSum;
    /** Running sum of topic-item distribution readouts. */
    protected DenseMatrix topicItemProbsSum;
    /** Averaged user-topic distributions (set by {@link #estimateParams()}). */
    protected DenseMatrix userTopicProbs;
    /** Averaged topic-item distributions (set by {@link #estimateParams()}). */
    protected DenseMatrix topicItemProbs;
    /** Number of readouts accumulated into the running sums. */
    protected int numStats = 0;

    /**
     * Reads the configuration, allocates the count and statistics structures, and gives every
     * training token a uniformly random initial topic.
     *
     * <p>Configuration keys: {@code rec.topic.number} (default 10),
     * {@code rec.user.dirichlet.prior} (default {@code 50 / numTopics}) and
     * {@code rec.topic.dirichlet.prior} (default 0.01).
     *
     * @throws LibrecException if the parent setup fails or {@code rec.topic.number} is not positive
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numTopics = this.conf.getInt("rec.topic.number", 10);
        if (this.numTopics <= 0) {
            // Fail fast: a non-positive topic count makes the default alpha (50 / numTopics)
            // non-finite and leaves nothing to sample from.
            throw new LibrecException("rec.topic.number must be positive, but was " + this.numTopics);
        }
        // Restart statistics accumulation: the sums below are freshly allocated, so the counter
        // must start over too, or a repeated setup() would mis-scale the averages.
        this.numStats = 0;
        this.userTopicProbsSum = new DenseMatrix(this.numUsers, this.numTopics);
        this.topicItemProbsSum = new DenseMatrix(this.numTopics, this.numItems);
        this.userTopicNumbers = new DenseMatrix(this.numUsers, this.numTopics);
        this.userTokenNumbers = new DenseVector(this.numUsers);
        this.topicItemNumbers = new DenseMatrix(this.numTopics, this.numItems);
        this.topicTokenNumbers = new DenseVector(this.numTopics);
        this.initAlpha = this.conf.getFloat("rec.user.dirichlet.prior", Float.valueOf(50.0f / (float)this.numTopics)).floatValue();
        this.initBeta = this.conf.getFloat("rec.topic.dirichlet.prior", Float.valueOf(0.01f)).floatValue();
        this.alpha = new DenseVector(this.numTopics);
        this.alpha.setAll(this.initAlpha);
        this.beta = new DenseVector(this.numItems);
        this.beta.setAll(this.initBeta);
        this.topicAssignments = new ArrayList<>(this.trainMatrix.size());
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            // Each entry contributes (int) value tokens; fractional parts are truncated.
            int numTokens = (int) matrixEntry.get();
            for (int tokenOfEntry = 0; tokenOfEntry < numTokens; ++tokenOfEntry) {
                int topicIdx = Randoms.uniform(this.numTopics);
                this.topicAssignments.add(topicIdx);
                updateCounts(userIdx, itemIdx, topicIdx, 1.0);
            }
        }
    }

    /**
     * Runs one collapsed Gibbs sweep over all tokens.
     *
     * <p>For each token, the current assignment is removed from the counts. A new topic is then
     * drawn with weight proportional to
     * {@code (n_uk + alpha_k) / (n_u + sum(alpha)) * (n_ki + beta_i) / (n_k + sum(beta))},
     * and the counts are restored with the new topic.
     */
    @Override
    protected void eStep() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        // Cumulative unnormalized topic weights, reused for every token.
        double[] cumulative = new double[this.numTopics];
        int tokenIdx = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            int numTokens = (int) matrixEntry.get();
            for (int tokenOfEntry = 0; tokenOfEntry < numTokens; ++tokenOfEntry, ++tokenIdx) {
                // Exclude this token from the counts before computing its conditional.
                updateCounts(userIdx, itemIdx, this.topicAssignments.get(tokenIdx), -1.0);
                double userNorm = this.userTokenNumbers.get(userIdx) + sumAlpha;
                double itemBeta = this.beta.get(itemIdx);
                double running = 0.0;
                for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                    running += (this.userTopicNumbers.get(userIdx, topicIdx) + this.alpha.get(topicIdx)) / userNorm
                            * (this.topicItemNumbers.get(topicIdx, itemIdx) + itemBeta)
                            / (this.topicTokenNumbers.get(topicIdx) + sumBeta);
                    cumulative[topicIdx] = running;
                }
                int newTopic = sampleIndex(cumulative);
                updateCounts(userIdx, itemIdx, newTopic, 1.0);
                this.topicAssignments.set(tokenIdx, newTopic);
            }
        }
    }

    /**
     * Re-estimates the Dirichlet hyper-parameters with a fixed-point update:
     * {@code alpha_k *= sum_u[psi(n_uk + alpha_k) - psi(alpha_k)] / sum_u[psi(n_u + A) - psi(A)]},
     * where {@code A = sum(alpha)}. {@code beta} is updated the same way over topics. A
     * component whose numerator is exactly zero is left unchanged.
     */
    @Override
    protected void mStep() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();

        // The alpha denominator depends only on sumAlpha (fixed for this step), so it is the
        // same for every topic; compute it once.
        double digammaSumAlpha = Gamma.digamma(sumAlpha);
        double alphaDenominator = 0.0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            alphaDenominator += Gamma.digamma(this.userTokenNumbers.get(userIdx) + sumAlpha) - digammaSumAlpha;
        }
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double topicAlpha = this.alpha.get(topicIdx);
            double digammaTopicAlpha = Gamma.digamma(topicAlpha);
            double numerator = 0.0;
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                numerator += Gamma.digamma(this.userTopicNumbers.get(userIdx, topicIdx) + topicAlpha) - digammaTopicAlpha;
            }
            if (numerator != 0.0) {
                this.alpha.set(topicIdx, topicAlpha * (numerator / alphaDenominator));
            }
        }

        // Likewise, the beta denominator depends only on sumBeta and is shared by every item.
        double digammaSumBeta = Gamma.digamma(sumBeta);
        double betaDenominator = 0.0;
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            betaDenominator += Gamma.digamma(this.topicTokenNumbers.get(topicIdx) + sumBeta) - digammaSumBeta;
        }
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            double itemBeta = this.beta.get(itemIdx);
            double digammaItemBeta = Gamma.digamma(itemBeta);
            double numerator = 0.0;
            for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                numerator += Gamma.digamma(this.topicItemNumbers.get(topicIdx, itemIdx) + itemBeta) - digammaItemBeta;
            }
            if (numerator != 0.0) {
                this.beta.set(itemIdx, itemBeta * (numerator / betaDenominator));
            }
        }
    }

    /**
     * Adds the current smoothed user-topic and topic-item distributions to the running sums
     * and increments {@link #numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            double userNorm = this.userTokenNumbers.get(userIdx) + sumAlpha;
            for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                double prob = (this.userTopicNumbers.get(userIdx, topicIdx) + this.alpha.get(topicIdx)) / userNorm;
                this.userTopicProbsSum.add(userIdx, topicIdx, prob);
            }
        }
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double topicNorm = this.topicTokenNumbers.get(topicIdx) + sumBeta;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double prob = (this.topicItemNumbers.get(topicIdx, itemIdx) + this.beta.get(itemIdx)) / topicNorm;
                this.topicItemProbsSum.add(topicIdx, itemIdx, prob);
            }
        }
        ++this.numStats;
    }

    /**
     * Sets {@link #userTopicProbs} and {@link #topicItemProbs} to the average of all readouts.
     * If no readout has been taken yet, one is taken from the current sampler state first.
     */
    @Override
    protected void estimateParams() {
        if (this.numStats == 0) {
            // Dividing the (all-zero) sums by zero readouts would fill both matrices with NaN.
            readoutParams();
        }
        double scale = 1.0 / (double) this.numStats;
        this.userTopicProbs = this.userTopicProbsSum.scale(scale);
        this.topicItemProbs = this.topicItemProbsSum.scale(scale);
    }

    /**
     * Scores a user-item pair as {@code sum_k P(k | user) * P(item | k)} using the averaged
     * distributions. This assumes {@code DenseMatrix.product(A, i, B, j)} is the dot product of
     * row {@code i} of A and column {@code j} of B, which is the only shape-consistent reading
     * for these (users x topics) and (topics x items) matrices.
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index
     * @return the predicted relevance score
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return DenseMatrix.product(this.userTopicProbs, userIdx, this.topicItemProbs, itemIdx);
    }

    /**
     * Applies {@code delta} to all four count structures for one token with the given
     * user, item and topic.
     */
    private void updateCounts(int userIdx, int itemIdx, int topicIdx, double delta) {
        this.userTopicNumbers.add(userIdx, topicIdx, delta);
        this.userTokenNumbers.add(userIdx, delta);
        this.topicItemNumbers.add(topicIdx, itemIdx, delta);
        this.topicTokenNumbers.add(topicIdx, delta);
    }

    /**
     * Draws an index with probability proportional to the increments of a cumulative weight
     * array.
     *
     * @param cumulative non-empty running sums of non-negative weights
     * @return the first index whose cumulative weight exceeds a uniform draw scaled by the total.
     *         The result is clamped to the last index, so a draw that is not strictly below the
     *         total (floating-point rounding, an all-zero total) cannot run past the end.
     */
    private static int sampleIndex(double[] cumulative) {
        int last = cumulative.length - 1;
        double target = Randoms.uniform() * cumulative[last];
        int idx = 0;
        while (idx < last && target >= cumulative[idx]) {
            ++idx;
        }
        return idx;
    }
}

