/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.FactorizationMachineRecommender;

/**
 * Factorization machine recommender for rating prediction whose parameters
 * (global bias {@code w0}, linear weights {@code W}, factor matrix {@code V})
 * are fitted by per-entry gradient steps on the squared prediction error.
 *
 * <p>The step size is read from the {@code rec.iterator.learnRate}
 * configuration key. Training is only performed when the recommender is not
 * in ranking mode.
 */
@ModelData(value={"isRanking", "fmsgd", "W", "V", "W0", "k"})
public class FMSGDRecommender
extends FactorizationMachineRecommender {
    /** Step size applied to every gradient update. */
    private double learnRate;

    /**
     * Runs the parent setup and then loads the learning rate from the
     * {@code rec.iterator.learnRate} configuration entry.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        learnRate = conf.getDouble("rec.iterator.learnRate");
    }

    /**
     * Trains the rating model. Nothing is trained when {@code isRanking} is set.
     *
     * @throws LibrecException if training fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        if (isRanking) {
            return;
        }
        buildRatingModel();
    }

    /**
     * Performs up to {@code numIterations} passes over the training tensor,
     * applying one gradient step per tensor entry. After each pass the
     * accumulated loss is halved and convergence is checked; the loop stops
     * when {@code isConverged} reports convergence and {@code earlyStop} is set.
     *
     * @throws LibrecException if prediction or the convergence check fails
     */
    private void buildRatingModel() throws LibrecException {
        int epoch = 0;
        while (epoch < numIterations) {
            lastLoss = loss;
            loss = 0.0;
            final int userDim = trainTensor.getUserDimension();
            final int itemDim = trainTensor.getItemDimension();
            for (TensorEntry entry : trainTensor) {
                applySgdStep(entry, userDim, itemDim);
            }
            loss *= 0.5;
            // isConverged is always evaluated first, exactly once per pass.
            if (isConverged(epoch) && earlyStop) {
                break;
            }
            ++epoch;
        }
    }

    /**
     * Applies one gradient step for a single training entry and adds that
     * entry's squared error and regularization terms to {@code loss}.
     *
     * @param entry   the training tensor entry (keys plus observed rating)
     * @param userDim position of the user key inside the entry's key array
     * @param itemDim position of the item key inside the entry's key array
     * @throws LibrecException if the prediction fails
     */
    private void applySgdStep(TensorEntry entry, int userDim, int itemDim) throws LibrecException {
        final int[] keys = entry.keys();
        final SparseVector features = tenserKeysToFeatureVector(keys);
        final double observed = entry.get();
        final double estimate = predict(keys[userDim], keys[itemDim], features);

        // The residual doubles as the gradient of the (halved) squared error.
        final double residual = estimate - observed;
        loss += residual * residual;

        // Global bias: its partial derivative of the prediction is 1.
        loss += regW0 * w0 * w0;
        final double biasGrad = residual + regW0 * w0;
        w0 += -learnRate * biasGrad;

        for (VectorEntry feature : features) {
            final int idx = feature.index();

            // Linear weight: the partial derivative is the feature value itself.
            final double prevWeight = W.get(idx);
            final double x = feature.get();
            final double weightGrad = residual * x + regW * prevWeight;
            W.add(idx, -learnRate * weightGrad);
            loss += regW * prevWeight * prevWeight;

            // Pairwise factors, one latent dimension at a time.
            for (int factor = 0; factor < k; ++factor) {
                final double prevFactor = V.get(idx, factor);
                final double interaction = crossFeatureSum(features, idx, factor, x);
                final double factorGrad = residual * interaction + regF * prevFactor;
                V.add(idx, factor, -learnRate * factorGrad);
                loss += regF * prevFactor * prevFactor;
            }
        }
    }

    /**
     * Sums {@code x * V[j][factor] * x_j} over every active feature {@code j}
     * of the vector other than {@code skip}, using the current contents of
     * {@code V}.
     *
     * @param features the active features of the current entry
     * @param skip     index of the feature being updated, excluded from the sum
     * @param factor   latent dimension being updated
     * @param x        value of the feature being updated
     * @return the accumulated interaction term
     */
    private double crossFeatureSum(SparseVector features, int skip, int factor, double x) {
        double total = 0.0;
        for (VectorEntry other : features) {
            final int j = other.index();
            if (j != skip) {
                total += x * V.get(j, factor) * other.get();
            }
        }
        return total;
    }

    /**
     * Placeholder two-argument prediction; always returns {@code 0.0}.
     * Training in this class uses the three-argument overload that also
     * takes the feature vector.
     *
     * @param userIdx user index (ignored)
     * @param itemIdx item index (ignored)
     * @return always {@code 0.0}
     * @throws LibrecException never thrown by this implementation
     */
    @Override
    @Deprecated
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return 0.0;
    }
}

