/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Rating-frequency based recommender (RF-Rec).
 *
 * <p>For every user (and every item) a frequency-based rating estimate is derived from how often each
 * value of the rating scale occurs among that user's (item's) training ratings. Each scale value is
 * scored with add-one smoothing plus a bonus of one when it equals the rounded average rating. The
 * estimate is the score-weighted mean of the scale values. A prediction is
 * {@code w_u * userEstimate + w_i * itemEstimate}, where the per-user and per-item weights are learned
 * by stochastic gradient descent on the squared training error.
 */
public class RFRecRecommender extends MatrixFactorizationRecommender {
    /** Mean training rating of each user. */
    private DenseVector userAverages;
    /** Mean training rating of each item. */
    private DenseVector itemAverages;
    /** Training-rating counts: rows are users, columns are positions in {@code ratingScale}. */
    private DenseMatrix userRatingFrequencies;
    /** Training-rating counts: rows are items, columns are positions in {@code ratingScale}. */
    private DenseMatrix itemRatingFrequencies;
    /** Learned weight of the user-side estimate, one entry per user. */
    private DenseVector userWeights;
    /** Learned weight of the item-side estimate, one entry per item. */
    private DenseVector itemWeights;
    /** Cached user-side frequency estimate; depends only on training data, so it is fixed after setup. */
    private DenseVector userEstimates;
    /** Cached item-side frequency estimate; depends only on training data, so it is fixed after setup. */
    private DenseVector itemEstimates;

    /**
     * Computes per-user/per-item rating averages, rating-frequency tables and frequency-based
     * estimates, and initializes user weights near 0.6 and item weights near 0.4.
     *
     * @throws LibrecException if the base-class setup fails
     * @throws IllegalStateException if a training rating is not contained in {@code ratingScale}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        userAverages = new DenseVector(numUsers);
        itemAverages = new DenseVector(numItems);
        userWeights = new DenseVector(numUsers);
        itemWeights = new DenseVector(numItems);
        for (int u = 0; u < numUsers; ++u) {
            userAverages.set(u, trainMatrix.row(u).mean());
            userWeights.set(u, 0.6 + Randoms.uniform() * 0.01);
        }
        for (int j = 0; j < numItems; ++j) {
            itemAverages.set(j, trainMatrix.column(j).mean());
            itemWeights.set(j, 0.4 + Randoms.uniform() * 0.01);
        }

        // Columns are positions in ratingScale, not raw rating values: every scale value (including
        // the maximum and any fractional value) gets its own valid column, and predict() reads the
        // same column for the same value.
        int numLevels = ratingScale.size();
        userRatingFrequencies = new DenseMatrix(numUsers, numLevels);
        itemRatingFrequencies = new DenseMatrix(numItems, numLevels);
        for (MatrixEntry matrixEntry : trainMatrix) {
            double rating = matrixEntry.get();
            int level = ratingScale.indexOf(rating);
            if (level < 0) {
                throw new IllegalStateException(
                        "Training rating " + rating + " is not in the rating scale " + ratingScale);
            }
            userRatingFrequencies.add(matrixEntry.row(), level, 1.0);
            itemRatingFrequencies.add(matrixEntry.column(), level, 1.0);
        }

        userEstimates = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; ++u) {
            userEstimates.set(u, frequencyEstimate(userRatingFrequencies, u, userAverages.get(u)));
        }
        itemEstimates = new DenseVector(numItems);
        for (int j = 0; j < numItems; ++j) {
            itemEstimates.set(j, frequencyEstimate(itemRatingFrequencies, j, itemAverages.get(j)));
        }
    }

    /**
     * Learns the user and item weights with stochastic gradient descent over all training ratings.
     *
     * @throws LibrecException declared for the base-class contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; ++iter) {
            for (MatrixEntry matrixEntry : trainMatrix) {
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double error = matrixEntry.get() - predict(userIdx, itemIdx);
                double userWeight = userWeights.get(userIdx);
                double itemWeight = itemWeights.get(itemIdx);
                // prediction = w_u * userEstimate + w_i * itemEstimate, so the partial derivative of the
                // prediction w.r.t. each weight is the corresponding estimate.
                userWeights.set(userIdx, userWeight
                        + learnRate * (error * userEstimates.get(userIdx) - regUser * userWeight));
                itemWeights.set(itemIdx, itemWeight
                        + learnRate * (error * itemEstimates.get(itemIdx) - regItem * itemWeight));
            }
        }
    }

    /**
     * Returns 1 if {@code avg}, rounded to the nearest integer, equals {@code rating}, otherwise 0.
     */
    private int isAvgRating(double avg, int rating) {
        return Math.round(avg) == (long) rating ? 1 : 0;
    }

    /**
     * Computes the frequency-based estimate for one row of a frequency table.
     *
     * <p>Each scale value v (rounded to an int) is scored as {@code count(v) + 1 + [round(average) == v]}.
     * The result is the score-weighted mean of the rounded scale values.
     *
     * @param frequencies user or item frequency table whose columns are positions in {@code ratingScale}
     * @param idx row (user or item index) to evaluate
     * @param average mean training rating of that row
     * @return weighted mean rating; NaN only if {@code ratingScale} is empty
     */
    private double frequencyEstimate(DenseMatrix frequencies, int idx, double average) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (int level = 0; level < ratingScale.size(); ++level) {
            int ratingValue = (int) Math.round((Double) ratingScale.get(level));
            double score = frequencies.get(idx, level) + 1 + isAvgRating(average, ratingValue);
            numerator += score * ratingValue;
            denominator += score;
        }
        return numerator / denominator;
    }

    /**
     * Returns {@code average} when it is usable (non-zero and not NaN), otherwise the global mean.
     */
    private double averageOrGlobalMean(double average) {
        // NOTE(review): the mean of a row without ratings may be NaN depending on the sparse-vector
        // implementation; such a value must not leak out as a prediction.
        return average != 0.0 && !Double.isNaN(average) ? average : globalMean;
    }

    /**
     * Predicts the rating of {@code userIdx} for {@code itemIdx}.
     *
     * <p>When both the user and the item have training ratings with a positive average, returns the
     * weighted sum of the cached user- and item-side frequency estimates. Otherwise it falls back to
     * the average of the side that is known, or to the global mean.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return predicted rating
     */
    @Override
    public double predict(int userIdx, int itemIdx) {
        double userAvg = userAverages.get(userIdx);
        double itemAvg = itemAverages.get(itemIdx);
        double userCount = userRatingFrequencies.row(userIdx).sum();
        double itemCount = itemRatingFrequencies.row(itemIdx).sum();
        if (userCount > 0.0 && itemCount > 0.0 && userAvg > 0.0 && itemAvg > 0.0) {
            return userWeights.get(userIdx) * userEstimates.get(userIdx)
                    + itemWeights.get(itemIdx) * itemEstimates.get(itemIdx);
        }
        if (userCount == 0.0 || userAvg == 0.0) {
            return averageOrGlobalMean(itemAvg);
        }
        if (itemCount == 0.0 || itemAvg == 0.0) {
            return averageOrGlobalMean(userAvg);
        }
        return globalMean;
    }
}

