/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender;

import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.recommender.AbstractRecommender;

/**
 * Base class for recommenders that score a (user, item) pair from a user latent-factor matrix
 * and an item latent-factor matrix.
 *
 * <p>{@link #setup()} reads the shared hyper-parameters from the configuration and allocates and
 * initialises both factor matrices. {@link #updateLRate(int)} implements the optional
 * learning-rate schedules (bold driver or multiplicative decay) and records the current loss.
 */
public abstract class MatrixFactorizationRecommender
extends AbstractRecommender {
    /** Current learning rate; may be adapted by {@link #updateLRate(int)}. */
    protected float learnRate;
    /** Cap applied to {@link #learnRate} after each update; a non-positive value disables the cap. */
    protected float maxLearnRate;
    /** User latent factors: {@code numUsers} rows by {@link #numFactors} columns. */
    protected DenseMatrix userFactors;
    /** Item latent factors: {@code numItems} rows by {@link #numFactors} columns. */
    protected DenseMatrix itemFactors;
    /** Number of latent factors (column count of both factor matrices). */
    protected int numFactors;
    /** Maximum number of training iterations (configuration key {@code rec.iterator.maximum}). */
    protected int numIterations;
    /** Mean passed to {@code DenseMatrix.init} when initialising the factor matrices. */
    protected float initMean;
    /** Standard deviation passed to {@code DenseMatrix.init} when initialising the factor matrices. */
    protected float initStd;
    /** Regularization coefficient for user factors (key {@code rec.user.regularization}). */
    protected float regUser;
    /** Regularization coefficient for item factors (key {@code rec.item.regularization}). */
    protected float regItem;

    /**
     * Reads the matrix-factorization hyper-parameters from the configuration, then allocates
     * and initialises the user and item factor matrices.
     *
     * <p>Configuration keys, with defaults in parentheses:
     * {@code rec.iterator.maximum} (100), {@code rec.iterator.learnrate} (0.01),
     * {@code rec.iterator.learnrate.maximum} (1000), {@code rec.user.regularization} (0.01),
     * {@code rec.item.regularization} (0.01), {@code rec.factor.number} (10),
     * {@code rec.learnrate.bolddriver} (false), {@code rec.learnrate.decay} (1.0),
     * {@code rec.init.mean} (0.0), {@code rec.init.std} (0.1).
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 100);
        this.learnRate = this.conf.getFloat("rec.iterator.learnrate", 0.01f);
        this.maxLearnRate = this.conf.getFloat("rec.iterator.learnrate.maximum", 1000.0f);
        this.regUser = this.conf.getFloat("rec.user.regularization", 0.01f);
        this.regItem = this.conf.getFloat("rec.item.regularization", 0.01f);
        this.numFactors = this.conf.getInt("rec.factor.number", 10);
        this.isBoldDriver = this.conf.getBoolean("rec.learnrate.bolddriver", false);
        this.decay = this.conf.getFloat("rec.learnrate.decay", 1.0f);
        // Previously hard-coded; the defaults keep the old behaviour when the keys are unset.
        this.initMean = this.conf.getFloat("rec.init.mean", 0.0f);
        this.initStd = this.conf.getFloat("rec.init.std", 0.1f);

        this.userFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.itemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.userFactors.init(this.initMean, this.initStd);
        this.itemFactors.init(this.initMean, this.initStd);
    }

    /**
     * Predicts the score of {@code itemIdx} for {@code userIdx} by combining row
     * {@code userIdx} of {@link #userFactors} with row {@code itemIdx} of {@link #itemFactors}
     * via {@code DenseMatrix.rowMult}.
     * NOTE(review): rowMult is assumed to be the inner product of the two rows; confirm.
     *
     * @param userIdx inner index of the user
     * @param itemIdx inner index of the item
     * @return the predicted score
     * @throws LibrecException declared for overriding implementations
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
    }

    /**
     * Adapts {@link #learnRate} for the next iteration and records the current loss.
     *
     * <ul>
     *   <li>A negative learning rate disables adaptation entirely.</li>
     *   <li>With bold driver enabled and {@code iter > 1}, the rate grows by 5% when the absolute
     *       loss strictly decreased since the previous call. Otherwise it is halved.</li>
     *   <li>Otherwise, if {@code 0 < decay < 1}, the rate is multiplied by {@code decay}.</li>
     * </ul>
     *
     * <p>After adaptation the rate is capped at {@link #maxLearnRate} when that is positive. In
     * every case, including when adaptation is disabled, {@code lastLoss} is set to the
     * current {@code loss}.
     *
     * @param iter the current training iteration number
     */
    protected void updateLRate(int iter) {
        if (this.learnRate < 0.0f) {
            // Adaptation is disabled, but the loss history must still advance. Otherwise
            // lastLoss keeps its pre-training value and any later lastLoss-vs-loss
            // comparison sees a stale value.
            this.lastLoss = this.loss;
            return;
        }
        if (this.isBoldDriver && iter > 1) {
            // Bold driver: reward an improvement in loss, punish a regression.
            this.learnRate = Math.abs(this.lastLoss) > Math.abs(this.loss)
                    ? this.learnRate * 1.05f
                    : this.learnRate * 0.5f;
        } else if (this.decay > 0.0f && this.decay < 1.0f) {
            this.learnRate *= this.decay;
        }
        // Bold driver can grow the rate without bound; clamp it if a maximum is configured.
        if (this.maxLearnRate > 0.0f && this.learnRate > this.maxLearnRate) {
            this.learnRate = this.maxLearnRate;
        }
        this.lastLoss = this.loss;
    }
}

