/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.ArrayList;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Probabilistic Matrix Factorization (BPMF) trained with Gibbs sampling.
 *
 * <p>Every outer iteration first resamples the hyper-parameters of the user and item factor
 * priors: a precision matrix drawn from a Wishart and a mean vector drawn from a Gaussian. It
 * then resamples every user factor row and every item factor row from its conditional
 * posterior. Predictions for the entries of {@code testMatrix} are the running average of the
 * sampled ratings {@code globalMean + userFactor . itemFactor} over the iterations after the
 * burn-in.
 */
public class BPMFRecommender
extends MatrixFactorizationRecommender {
    /** Number of leading Gibbs iterations whose samples are not averaged into predictions. */
    private static final int BURN_IN_ITERATIONS = 1;

    // Prior hyper-parameters as read from the configuration.
    private double userMu0;
    private double userBeta0;
    private double userWishartScale0;
    private double itemMu0;
    private double itemBeta0;
    private double itemWishartScale0;

    // Prior hyper-parameters in the form used by the sampler.
    private DenseVector userMu;
    private DenseVector itemMu;
    /** Inverse of the user prior Wishart scale matrix. */
    private DenseMatrix userWishartScale;
    /** Inverse of the item prior Wishart scale matrix. */
    private DenseMatrix itemWishartScale;
    private double userBeta;
    private double itemBeta;
    private double userWishartNu;
    private double itemWishartNu;

    /**
     * Multiplies X^T X and X^T r in {@link #updateParameters}, so despite its name it acts as
     * the observation-noise precision.
     */
    private double ratingSigma;

    /** Running average of sampled predictions; its structure is that of {@code testMatrix}. */
    private SparseMatrix predictMatrix;

    /**
     * Reads the prior hyper-parameters and the rating-noise parameter from the configuration.
     *
     * @throws LibrecException propagated from the parent setup
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.userMu0 = this.conf.getDouble("rec.recommender.user.mu", 0.0);
        this.userBeta0 = this.conf.getDouble("rec.recommender.user.beta", 1.0);
        this.userWishartScale0 = this.conf.getDouble("rec.recommender.user.wishart.scale", 1.0);
        this.itemMu0 = this.conf.getDouble("rec.recommender.item.mu", 0.0);
        this.itemBeta0 = this.conf.getDouble("rec.recommender.item.beta", 1.0);
        this.itemWishartScale0 = this.conf.getDouble("rec.recommender.item.wishart.scale", 1.0);
        this.ratingSigma = this.conf.getDouble("rec.recommender.rating.sigma", 2.0);
    }

    /**
     * Builds the prior hyper-parameters and the prediction accumulator.
     *
     * <p>The Wishart scale fields hold the inverse of the configured diagonal prior scale, which
     * is the form {@link #samplingHyperParameters} adds the data scatter to. Every test entry of
     * the accumulator starts at {@code globalMean}.
     *
     * @throws LibrecException declared for compatibility with callers and overrides
     */
    protected void initModel() throws LibrecException {
        this.userMu = new DenseVector(this.numFactors);
        this.userMu.setAll(this.userMu0);
        this.itemMu = new DenseVector(this.numFactors);
        this.itemMu.setAll(this.itemMu0);
        this.userBeta = this.userBeta0;
        this.itemBeta = this.itemBeta0;

        DenseMatrix userScale0 = new DenseMatrix(this.numFactors, this.numFactors);
        DenseMatrix itemScale0 = new DenseMatrix(this.numFactors, this.numFactors);
        for (int f = 0; f < this.numFactors; ++f) {
            userScale0.set(f, f, this.userWishartScale0);
            itemScale0.set(f, f, this.itemWishartScale0);
        }
        // Keep the value returned by inv(), as the other inv() call sites in this class do;
        // the sampler expects the inverse prior scale.
        this.userWishartScale = userScale0.inv();
        this.itemWishartScale = itemScale0.inv();
        this.userWishartNu = this.numFactors;
        this.itemWishartNu = this.numFactors;

        // The copy carries the held-out test ratings. Overwrite them so predict() never returns
        // ground truth; an untrained model then falls back to the global mean.
        this.predictMatrix = new SparseMatrix(this.testMatrix);
        for (MatrixEntry me : this.testMatrix) {
            this.predictMatrix.set(me.row(), me.column(), this.globalMean);
        }
    }

    /**
     * Runs {@code numIterations} Gibbs sweeps.
     *
     * <p>Each sweep resamples the hyper-parameters, then the user factors, then the item
     * factors. For each sweep after the burn-in, the sampled test predictions are folded into
     * {@link #predictMatrix}.
     *
     * @throws LibrecException propagated from the sampling steps
     */
    @Override
    protected void trainModel() throws LibrecException {
        this.initModel();

        // Cache each user's row and each item's column of the training matrix.
        ArrayList<SparseVector> userTrainVectors = new ArrayList<SparseVector>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userTrainVectors.add(this.trainMatrix.row(userIdx));
        }
        ArrayList<SparseVector> itemTrainVectors = new ArrayList<SparseVector>(this.numItems);
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            itemTrainVectors.add(this.trainMatrix.column(itemIdx));
        }

        // Start from the empirical mean and inverse covariance of the initial factors.
        HyperParameters userHyperParameters = new HyperParameters(columnMeans(this.userFactors), this.userFactors.cov().inv());
        HyperParameters itemHyperParameters = new HyperParameters(columnMeans(this.itemFactors), this.itemFactors.cov().inv());

        for (int iter = 0; iter < this.numIterations; ++iter) {
            userHyperParameters = this.samplingHyperParameters(userHyperParameters, this.userFactors, this.userMu, this.userBeta, this.userWishartScale, this.userWishartNu);
            itemHyperParameters = this.samplingHyperParameters(itemHyperParameters, this.itemFactors, this.itemMu, this.itemBeta, this.itemWishartScale, this.itemWishartNu);

            // Rows with no training ratings keep their current factors.
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                SparseVector ratings = userTrainVectors.get(userIdx);
                if (ratings.getCount() > 0) {
                    this.userFactors.setRow(userIdx, this.updateParameters(this.itemFactors, ratings, userHyperParameters));
                }
            }
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                SparseVector ratings = itemTrainVectors.get(itemIdx);
                if (ratings.getCount() > 0) {
                    this.itemFactors.setRow(itemIdx, this.updateParameters(this.userFactors, ratings, itemHyperParameters));
                }
            }

            if (iter < BURN_IN_ITERATIONS) {
                continue;
            }
            // 1-based index of this sample within the averaged (post burn-in) sequence.
            int numSamples = iter - BURN_IN_ITERATIONS + 1;
            for (MatrixEntry me : this.testMatrix) {
                int userIdx = me.row();
                int itemIdx = me.column();
                double sample = this.globalMean + DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
                double previousMean = this.predictMatrix.get(userIdx, itemIdx);
                this.predictMatrix.set(userIdx, itemIdx, (previousMean * (numSamples - 1) + sample) / numSamples);
            }
        }
    }

    /**
     * Draws new hyper-parameters (precision matrix and mean) for one side's factor prior from
     * its Normal-Wishart conditional posterior given {@code factors}.
     *
     * <p>{@code hyperParameters} is updated in place and returned. If the Wishart draw returns
     * {@code null}, the previous precision is kept. If the Cholesky factorisation for the mean
     * fails, the previous mean is kept.
     *
     * @param hyperParameters current hyper-parameters; mutated and returned
     * @param factors         factor matrix, one row per user or item
     * @param normalMu0       prior mean of the hyper-mean
     * @param normalBeta0     prior scaling of the hyper-mean's precision
     * @param WishartScale0   inverse of the prior Wishart scale matrix
     * @param WishartNu0      prior Wishart degrees of freedom
     * @return {@code hyperParameters}, updated
     * @throws LibrecException declared for compatibility with callers and overrides
     */
    protected HyperParameters samplingHyperParameters(HyperParameters hyperParameters, DenseMatrix factors, DenseVector normalMu0, double normalBeta0, DenseMatrix WishartScale0, double WishartNu0) throws LibrecException {
        int numRows = factors.numRows();
        int numColumns = factors.numColumns();
        DenseVector mean = columnMeans(factors);
        DenseMatrix populationVariance = factors.cov();

        double betaPost = normalBeta0 + numRows;
        double nuPost = WishartNu0 + numRows;
        DenseVector muPost = normalMu0.scale(normalBeta0).add(mean.scale(numRows)).scale(1.0 / betaPost);

        DenseVector muError = normalMu0.minus(mean);
        DenseMatrix wishartScalePost = WishartScale0
                .add(populationVariance.scale(numRows))
                .add(muError.outer(muError).scale(normalBeta0 * numRows / betaPost))
                .inv();
        // Symmetrize to remove round-off asymmetry before the Wishart draw.
        wishartScalePost = wishartScalePost.add(wishartScalePost.transpose()).scale(0.5);

        DenseMatrix precision = Randoms.wishart(wishartScalePost, nuPost);
        if (precision != null) {
            hyperParameters.variance = precision;
        }

        // The hyper-mean's posterior covariance is (betaPost * precision)^-1.
        DenseMatrix choleskyFactor = hyperParameters.variance.scale(betaPost).inv().cholesky();
        if (choleskyFactor != null) {
            DenseMatrix transform = choleskyFactor.transpose();
            DenseVector noise = new DenseVector(numColumns);
            for (int f = 0; f < numColumns; ++f) {
                noise.set(f, Randoms.gaussian(0.0, 1.0));
            }
            hyperParameters.mu = transform.mult(noise).add(muPost);
        }
        return hyperParameters;
    }

    /**
     * Samples one factor row from its Gaussian conditional posterior.
     *
     * <p>The conditioning data are the ratings in {@code ratings} (centred by
     * {@code globalMean}) and the rows of {@code factors} they index.
     *
     * @param factors         factors of the other side (items when sampling a user, and vice versa)
     * @param ratings         this row's training ratings, indexed into {@code factors}
     * @param hyperParameters current prior mean and precision for this side
     * @return a sampled factor vector of length {@code numFactors}; if the Cholesky
     *         factorisation fails, the posterior mean
     * @throws LibrecException declared for compatibility with callers and overrides
     */
    protected DenseVector updateParameters(DenseMatrix factors, SparseVector ratings, HyperParameters hyperParameters) throws LibrecException {
        int num = ratings.getCount();
        DenseMatrix XX = new DenseMatrix(num, this.numFactors);
        DenseVector ratingsReg = new DenseVector(num);
        int row = 0;
        for (int j : ratings.getIndex()) {
            ratingsReg.set(row, ratings.get(j) - this.globalMean);
            XX.setRow(row, factors.row(j));
            ++row;
        }
        DenseMatrix XXt = XX.transpose();
        DenseMatrix covar = hyperParameters.variance.add(XXt.mult(XX).scale(this.ratingSigma)).inv();
        DenseVector mu = XXt.mult(ratingsReg).scale(this.ratingSigma);
        mu.addEqual(hyperParameters.variance.mult(hyperParameters.mu));
        mu = covar.mult(mu);

        DenseMatrix choleskyFactor = covar.cholesky();
        if (choleskyFactor == null) {
            // Cannot draw from a non-factorisable covariance. Use the posterior mean rather
            // than zeroing the row, which would discard everything learned for it.
            return mu;
        }
        DenseMatrix transform = choleskyFactor.transpose();
        DenseVector noise = new DenseVector(this.numFactors);
        for (int f = 0; f < this.numFactors; ++f) {
            noise.set(f, Randoms.gaussian(0.0, 1.0));
        }
        return transform.mult(noise).add(mu);
    }

    /**
     * Returns the averaged sampled prediction for a test entry.
     *
     * <p>For (user, item) pairs outside {@code testMatrix}, returns whatever
     * {@code SparseMatrix.get} yields for an absent entry.
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.predictMatrix.get(userIdx, itemIdx);
    }

    /** Returns a vector holding the mean of each column of {@code matrix}. */
    private static DenseVector columnMeans(DenseMatrix matrix) throws LibrecException {
        int numColumns = matrix.numColumns();
        DenseVector means = new DenseVector(numColumns);
        for (int c = 0; c < numColumns; ++c) {
            means.set(c, matrix.columnMean(c));
        }
        return means;
    }

    /**
     * Mutable pair of Gaussian-prior hyper-parameters.
     *
     * <p>{@code variance} is used as a precision matrix by the sampler.
     */
    public static class HyperParameters {
        public DenseVector mu;
        public DenseMatrix variance;

        HyperParameters(DenseVector _mu, DenseMatrix _variance) {
            this.mu = _mu;
            this.variance = _variance;
        }
    }
}

