/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.HashMap;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.FactorizationMachineRecommender;

@ModelData(value={"isRanking", "ffm", "W", "V", "W0", "k"})
public class FFMRecommender
extends FactorizationMachineRecommender {
    private double learnRate;
    private HashMap<Integer, Integer> map = new HashMap();

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.V = new DenseMatrix(this.p, this.k * this.trainTensor.numDimensions);
        this.V.init(0.0, 0.1);
        int colindex = 0;
        for (int dim = 0; dim < this.trainTensor.numDimensions; ++dim) {
            for (int index = 0; index < this.trainTensor.dimensions[dim]; ++index) {
                this.map.put(colindex + index, dim);
            }
            colindex += this.trainTensor.dimensions[dim];
        }
        this.learnRate = this.conf.getDouble("rec.iterator.learnRate");
    }

    @Override
    protected void trainModel() throws LibrecException {
        if (!this.isRanking) {
            this.buildRatingModel();
        }
    }

    private void buildRatingModel() throws LibrecException {
        for (int iter = 0; iter < this.numIterations; ++iter) {
            this.loss = 0.0;
            int userDimension = this.trainTensor.getUserDimension();
            int itemDimension = this.trainTensor.getItemDimension();
            for (TensorEntry me : this.trainTensor) {
                int[] entryKeys = me.keys();
                SparseVector x = this.tenserKeysToFeatureVector(entryKeys);
                double rate = me.get();
                double pred = this.predict(entryKeys[userDimension], entryKeys[itemDimension], x);
                double err = pred - rate;
                this.loss += err * err;
                double gradLoss = err;
                this.loss += (double)this.regW0 * this.w0 * this.w0;
                double hW0 = 1.0;
                double gradW0 = gradLoss * hW0 + (double)this.regW0 * this.w0;
                this.w0 += -this.learnRate * gradW0;
                for (int l = 0; l < this.p; ++l) {
                    if (!x.contains(l)) continue;
                    double oldWl = this.W.get(l);
                    double hWl = x.get(l);
                    double gradWl = gradLoss * hWl + (double)this.regW * oldWl;
                    this.W.add(l, -this.learnRate * gradWl);
                    this.loss += (double)this.regW * oldWl * oldWl;
                    for (int f = 0; f < this.k; ++f) {
                        double oldVlf = this.V.get(l, this.map.get(l) + f);
                        double hVlf = 0.0;
                        double xl = x.get(l);
                        for (int j = 0; j < this.p; ++j) {
                            if (j == l || !x.contains(j)) continue;
                            hVlf += xl * this.V.get(j, this.map.get(l) + f) * x.get(j);
                        }
                        double gradVlf = gradLoss * hVlf + (double)this.regF * oldVlf;
                        this.V.add(l, this.map.get(l) + f, -this.learnRate * gradVlf);
                        this.loss += (double)this.regF * oldVlf * oldVlf;
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
        }
    }

    @Override
    protected double predict(int userId, int itemId, SparseVector x) throws LibrecException {
        double res = 0.0;
        res += this.w0;
        for (VectorEntry ve : x) {
            double val = ve.get();
            int ind = ve.index();
            res += val * this.W.get(ind);
        }
        for (int f = 0; f < this.k; ++f) {
            double sum = 0.0;
            for (VectorEntry vi : x) {
                for (VectorEntry vj : x) {
                    int j;
                    double xi = vi.get();
                    double xj = vj.get();
                    int i = vi.index();
                    if (i == (j = vj.index())) continue;
                    double vifj = this.V.get(i, this.map.get(j) + f);
                    double vjfi = this.V.get(j, this.map.get(i) + f);
                    sum += vifj * vjfi * xi * xj;
                }
            }
            res += sum;
        }
        return res;
    }

    @Override
    @Deprecated
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return 0.0;
    }
}

