/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import com.google.common.collect.Table;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * FISM-style (factored item similarity) recommender trained by stochastic gradient descent on
 * squared error.
 *
 * <p>The score for user {@code u} and item {@code i} is
 * {@code b_u + b_i + w * sum_{j in R_u, j != i} P_j . Q_i}, where {@code R_u} is the set of items
 * stored for {@code u} in the training matrix and {@code w} is a count raised to {@code -alpha}.
 * Each training iteration fits the stored ratings plus {@code rho * nnz} randomly sampled unrated
 * (user, item) cells, which are given target value 0.
 */
@ModelData(value={"isRanking", "fismrmse", "P", "Q", "itemBiases", "userBiases"})
public class FISMrmseRecommender
extends MatrixFactorizationRecommender {
    /** Guava cache spec for {@link #userItemsCache}; static, so shared by all instances. */
    protected static String cacheSpec;
    /** Per-user list of rated item indices (built by {@code trainMatrix.rowColumnsCache}); used by {@link #predict}. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** {@code trainMatrix.size()}, i.e. the number of training entries. */
    private int nnz;
    /** Number of sampled unrated cells per iteration, as a multiple of {@link #nnz}. */
    private float rho;
    /** Exponent applied to the normalizing count in {@code count^-alpha}. */
    private float alpha;
    /** L2 regularization weight on the factor matrices {@link #P} and {@link #Q}. */
    private float beta;
    /** L2 regularization weight on {@link #itemBiases}. */
    private float itemBiasReg;
    /** L2 regularization weight on {@link #userBiases}. */
    private float userBiasReg;
    /** SGD step size, read from {@code rec.iteration.learnrate}. */
    private double lRate;
    private DenseVector itemBiases;
    private DenseVector userBiases;
    /** Factors of the items a user has rated; summed to form the user's representation. */
    private DenseMatrix P;
    /** Factors of the item being scored. */
    private DenseMatrix Q;

    /**
     * Allocates and randomly initializes factors and biases, and reads hyper-parameters.
     *
     * @throws LibrecException if the superclass setup fails or {@code rec.recommender.rho} is unset
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.P = new DenseMatrix(this.numItems, this.numFactors);
        this.Q = new DenseMatrix(this.numItems, this.numFactors);
        this.P.init(0.0, 0.01);
        this.Q.init(0.0, 0.01);
        this.userBiases = new DenseVector(this.numUsers);
        this.itemBiases = new DenseVector(this.numItems);
        this.userBiases.init(0.0, 0.01);
        this.itemBiases.init(0.0, 0.01);
        this.nnz = this.trainMatrix.size();
        // rho has no default; report the missing key instead of failing with an NPE on unboxing.
        Float rhoValue = this.conf.getFloat("rec.recommender.rho");
        if (rhoValue == null) {
            throw new LibrecException("rec.recommender.rho must be configured");
        }
        this.rho = rhoValue.floatValue();
        this.alpha = this.conf.getFloat("rec.recommender.alpha", Float.valueOf(0.5f)).floatValue();
        this.beta = this.conf.getFloat("rec.recommender.beta", Float.valueOf(0.6f)).floatValue();
        this.itemBiasReg = this.conf.getFloat("rec.recommender.itemBiasReg", Float.valueOf(0.1f)).floatValue();
        this.userBiasReg = this.conf.getFloat("rec.recommender.userBiasReg", Float.valueOf(0.1f)).floatValue();
        this.lRate = this.conf.getDouble("rec.iteration.learnrate", 1.0E-4);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIterations} SGD epochs over the training entries plus freshly sampled
     * unrated cells, accumulating half the regularized squared error into {@code loss}.
     *
     * @throws LibrecException if negative sampling fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        int sampleSize = (int) (this.rho * (float) this.nnz);
        // Math.toIntExact fails loudly instead of silently wrapping when users * items overflows int.
        int numUnrated = Math.toIntExact((long) this.numUsers * this.numItems - this.nnz);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            Table<Integer, Integer, Double> R = this.trainMatrix.getDataTable();
            this.addNegativeSamples(R, sampleSize, numUnrated);
            for (Table.Cell<Integer, Integer, Double> cell : R.cellSet()) {
                this.sgdUpdate(cell.getRowKey(), cell.getColumnKey(), cell.getValue());
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            // NOTE(review): updateLRate is inherited; it presumably adjusts a superclass learning-rate
            // field rather than this.lRate, in which case the step size here never changes -- confirm.
            this.updateLRate(iter);
        }
    }

    /**
     * Puts target 0 into {@code R} for {@code min(sampleSize, numUnrated)} randomly chosen cells
     * whose value in the training matrix is 0.
     *
     * @param R          table to receive the sampled cells
     * @param sampleSize requested number of sampled cells
     * @param numUnrated number of zero-valued cells in the training matrix
     * @throws LibrecException if the random index generator fails
     */
    private void addNegativeSamples(Table<Integer, Integer, Double> R, int sampleSize, int numUnrated)
            throws LibrecException {
        int toSample = Math.min(sampleSize, numUnrated);
        if (toSample <= 0) {
            return;
        }
        List<Integer> indices;
        try {
            indices = Randoms.randInts(toSample, 0, numUnrated);
        } catch (Exception e) {
            throw new LibrecException(e);
        }
        if (indices == null || indices.isEmpty()) {
            return;
        }
        // Walk unrated cells in row-major order, matching their ordinal against the sampled indices.
        // NOTE(review): this merge scan assumes randInts returns sorted, distinct values -- confirm.
        int next = 0;
        int zeroOrdinal = 0;
        for (int u = 0; u < this.numUsers; ++u) {
            for (int j = 0; j < this.numItems; ++j) {
                if (this.trainMatrix.get(u, j) != 0.0) continue;
                if (zeroOrdinal++ != indices.get(next).intValue()) continue;
                R.put(u, j, 0.0);
                if (++next >= indices.size()) {
                    return;
                }
            }
        }
    }

    /**
     * Performs one SGD step on training cell (u, i) with target {@code rui} and adds its error and
     * regularization terms to {@code loss}.
     */
    private void sgdUpdate(int u, int i, double rui) {
        SparseVector ratedRow = this.trainMatrix.row(u);
        int[] ratedItems = ratedRow.getIndex();
        // Normalizing count is |R_u| - 1, floored at 1 so the power stays finite.
        int nU = Math.max(ratedRow.size() - 1, 1);
        double norm = Math.pow(nU, -this.alpha);

        DenseVector x = new DenseVector(this.numFactors);
        for (int j : ratedItems) {
            if (j == i) continue;
            x = x.add(this.P.row(j));
        }
        x = x.scale(norm);

        double bi = this.itemBiases.get(i);
        double bu = this.userBiases.get(u);
        // Snapshot of Q_i before the update: both the Q and P gradients are evaluated at this point.
        DenseVector qi = this.Q.row(i);
        double eui = rui - (bu + bi + qi.inner(x));
        this.loss += eui * eui;

        this.itemBiases.add(i, this.lRate * (eui - (double) this.itemBiasReg * bi));
        this.loss += (double) this.itemBiasReg * bi * bi;
        this.userBiases.add(u, this.lRate * (eui - (double) this.userBiasReg * bu));
        this.loss += (double) this.userBiasReg * bu * bu;

        DenseVector deltaq = x.scale(eui).minus(qi.scale(this.beta));
        this.loss += (double) this.beta * qi.inner(qi);
        this.Q.setRow(i, qi.add(deltaq.scale(this.lRate)));

        // Data part of the P_j gradient is identical for every j, so compute it once.
        DenseVector pGradData = qi.scale(eui * norm);
        for (int j : ratedItems) {
            if (j == i) continue;
            DenseVector pj = this.P.row(j);
            DenseVector deltap = pGradData.minus(pj.scale(this.beta));
            this.loss += (double) this.beta * pj.inner(pj);
            this.P.setRow(j, pj.add(deltap.scale(this.lRate)));
        }
    }

    /**
     * Scores item {@code j} for user {@code u}.
     *
     * @return {@code b_u + b_j + (count - 1)^-alpha * sum_i P_i . Q_j} over the user's rated items
     *         {@code i != j}, where {@code count} is the number of such items; the factor term is
     *         dropped when {@code count <= 1}
     * @throws LibrecException if the rated-items cache fails to load the user's items
     */
    @Override
    protected double predict(int u, int j) throws LibrecException {
        List<Integer> ratedItems;
        try {
            ratedItems = this.userItemsCache.get(u);
        } catch (ExecutionException e) {
            throw new LibrecException(e);
        }
        double sum = 0.0;
        int count = 0;
        for (int i : ratedItems) {
            if (i == j) continue;
            sum += DenseMatrix.rowMult(this.P, i, this.Q, j);
            ++count;
        }
        // NOTE(review): normalization here ((count - 1), zero when count <= 1) differs from training
        // (max(|R_u| - 1, 1)); kept as-is to preserve existing scores -- confirm which is intended.
        double wu = count > 1 ? Math.pow(count - 1, -this.alpha) : 0.0;
        return this.userBiases.get(u) + this.itemBiases.get(j) + wu * sum;
    }
}

