/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.FactorizationMachineRecommender;

/**
 * Factorization Machine recommender fitted by alternating least squares: every iteration updates
 * the global bias {@code w0}, each linear weight {@code W[l]} and each factor weight
 * {@code V[l][f]} in closed form, one parameter at a time, while the per-instance residuals and
 * factor sums are kept up to date incrementally.
 *
 * <p>NOTE(review): the residual vector and {@link #Q} are indexed by instance position {@code i}.
 * That position is assumed to be the same in {@code trainTensor.keys(i)} (used to build the
 * design matrix) and in the iteration order of {@code trainTensor} (used to fill residuals), and
 * {@code n} is assumed to equal the number of training entries. TODO confirm against the parent
 * class and the tensor implementation.
 */
@ModelData(value={"isRanking", "fmals", "W", "V", "W0", "k"})
public class FMALSRecommender
extends FactorizationMachineRecommender {
    /** Cached factor sums per training instance: {@code Q[i][f] = sum_l V[l][f] * x_il}. */
    private DenseMatrix Q;

    /**
     * Feature-major design matrix ({@code p} rows by {@code n} columns). Row {@code l} holds a
     * {@code 1.0} for every training instance whose feature vector has feature {@code l} set, so
     * iterating row {@code l} visits exactly the instances affected by a change to parameters of
     * feature {@code l}.
     */
    private SparseMatrix trainFeatureMatrix;

    /**
     * Allocates the factor-sum cache and builds the one-hot design matrix of the training tensor.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.Q = new DenseMatrix(this.n, this.k);
        HashBasedTable<Integer, Integer, Double> trainTable = HashBasedTable.create();
        for (int i = 0; i < this.n; ++i) {
            int[] ratingKeys = this.trainTensor.keys(i);
            // Each tensor dimension occupies a contiguous block of feature ids; the offset of
            // dimension j is the sum of the sizes of dimensions 0..j-1.
            int colPrefix = 0;
            for (int j = 0; j < ratingKeys.length; ++j) {
                int featureIdx = colPrefix + ratingKeys[j];
                colPrefix += this.trainTensor.dimensions[j];
                // Row = feature, column = instance. trainModel() walks rows by feature id and
                // treats each entry's index() as an instance id. Storing instance-major here
                // would make those walks read the wrong axis.
                trainTable.put(featureIdx, i, 1.0);
            }
        }
        this.trainFeatureMatrix = new SparseMatrix(this.p, this.n, trainTable);
    }

    /**
     * Runs up to {@code numIterations} ALS sweeps over w0, W and V, stopping early when
     * {@code earlyStop} is set and {@code isConverged} reports convergence.
     *
     * <p>{@code loss} for an iteration is the squared training error plus the L2 penalties,
     * all evaluated at the parameters in effect at the start of that iteration.
     *
     * @throws LibrecException propagated from prediction during residual initialisation
     */
    @Override
    protected void trainModel() throws LibrecException {
        DenseVector errors = this.initErrorsAndFactorSums();
        for (int iter = 0; iter < this.numIterations; ++iter) {
            this.lastLoss = this.loss;
            this.loss = 0.0;
            this.updateGlobalBias(errors);
            this.updateLinearWeights(errors);
            this.updateFactorWeights(errors);
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
        }
    }

    /**
     * Computes the initial residual {@code rating - prediction} of every training entry and fills
     * {@link #Q} with the matching factor sums.
     *
     * @return residuals indexed by the entry's position in the tensor iteration
     * @throws LibrecException propagated from feature-vector construction or prediction
     */
    private DenseVector initErrorsAndFactorSums() throws LibrecException {
        DenseVector errors = new DenseVector(this.n);
        int userDimension = this.trainTensor.getUserDimension();
        int itemDimension = this.trainTensor.getItemDimension();
        int ind = 0;
        for (TensorEntry entry : this.trainTensor) {
            int[] entryKeys = entry.keys();
            SparseVector x = this.tenserKeysToFeatureVector(entryKeys);
            double pred = this.predict(entryKeys[userDimension], entryKeys[itemDimension], x);
            errors.set(ind, entry.get() - pred);
            for (int f = 0; f < this.k; ++f) {
                double sumQ = 0.0;
                for (VectorEntry ve : x) {
                    sumQ += this.V.get(ve.index(), f) * ve.get();
                }
                this.Q.set(ind, f, sumQ);
            }
            ++ind;
        }
        return errors;
    }

    /**
     * Closed-form update of the global bias. Every instance has an implicit feature value of 1,
     * so the least-squares solution is {@code sum_i (w0 + e_i) / (n + regW0)}.
     *
     * @param errors residuals, updated in place to reflect the new bias
     */
    private void updateGlobalBias(DenseVector errors) {
        double oldW0 = this.w0;
        double numerator = 0.0;
        for (int i = 0; i < this.n; ++i) {
            numerator += oldW0 + errors.get(i);
        }
        // int + double keeps the denominator in double precision even if regW0 is a float.
        double newW0 = numerator / (this.n + (double) this.regW0);
        for (int i = 0; i < this.n; ++i) {
            double oldErr = errors.get(i);
            // Squared error is taken from residuals of the start-of-iteration parameters.
            this.loss += oldErr * oldErr;
            errors.set(i, oldErr + (oldW0 - newW0));
        }
        // Penalise the pre-update bias, matching the W and V terms (which also use old values).
        this.loss += (double) this.regW0 * oldW0 * oldW0;
        this.w0 = newW0;
    }

    /**
     * Closed-form update of each linear weight {@code W[l]}, visiting only the instances in which
     * feature {@code l} is active.
     *
     * @param errors residuals, updated in place after each weight changes
     */
    private void updateLinearWeights(DenseVector errors) {
        for (int l = 0; l < this.p; ++l) {
            double oldWl = this.W.get(l);
            double numerator = 0.0;
            double denominator = 0.0;
            Iterator<VectorEntry> it = this.trainFeatureMatrix.rowIterator(l);
            while (it.hasNext()) {
                VectorEntry entry = it.next();
                // For a linear weight, d(prediction)/d(W[l]) is simply x_il.
                double h = entry.get();
                numerator += oldWl * h * h + h * errors.get(entry.index());
                denominator += h * h;
            }
            double newWl = numerator / (denominator + (double) this.regW);
            it = this.trainFeatureMatrix.rowIterator(l);
            while (it.hasNext()) {
                VectorEntry entry = it.next();
                int i = entry.index();
                errors.set(i, errors.get(i) + (oldWl - newWl) * entry.get());
            }
            this.W.set(l, newWl);
            this.loss += (double) this.regW * oldWl * oldWl;
        }
    }

    /**
     * Closed-form update of each factor weight {@code V[l][f]}. The prediction is linear in a
     * single {@code V[l][f]}, with gradient {@code h = x_il * (Q[i][f] - V[l][f] * x_il)}.
     *
     * @param errors residuals, updated in place together with {@link #Q}
     */
    private void updateFactorWeights(DenseVector errors) {
        for (int f = 0; f < this.k; ++f) {
            for (int l = 0; l < this.p; ++l) {
                double oldVlf = this.V.get(l, f);
                double numerator = 0.0;
                double denominator = 0.0;
                Iterator<VectorEntry> it = this.trainFeatureMatrix.rowIterator(l);
                while (it.hasNext()) {
                    VectorEntry entry = it.next();
                    int i = entry.index();
                    double xVal = entry.get();
                    double h = xVal * (this.Q.get(i, f) - oldVlf * xVal);
                    numerator += oldVlf * h * h + h * errors.get(i);
                    denominator += h * h;
                }
                double newVlf = numerator / (denominator + (double) this.regF);
                it = this.trainFeatureMatrix.rowIterator(l);
                while (it.hasNext()) {
                    VectorEntry entry = it.next();
                    int i = entry.index();
                    double xVal = entry.get();
                    double oldQif = this.Q.get(i, f);
                    double newQif = oldQif + (newVlf - oldVlf) * xVal;
                    double hOld = xVal * (oldQif - oldVlf * xVal);
                    double hNew = xVal * (newQif - newVlf * xVal);
                    errors.set(i, errors.get(i) + oldVlf * hOld - newVlf * hNew);
                    this.Q.set(i, f, newQif);
                }
                this.V.set(l, f, newVlf);
                this.loss += (double) this.regF * oldVlf * oldVlf;
            }
        }
    }

    /**
     * Not supported by this recommender. Predictions require the full feature vector and go
     * through the three-argument {@code predict} inherited from the parent class.
     *
     * @param userIdx ignored
     * @param itemIdx ignored
     * @return always {@code 0.0}
     * @deprecated use the parent's feature-vector {@code predict(int, int, SparseVector)}
     */
    @Override
    @Deprecated
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return 0.0;
    }
}

