/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Ranking recommender based on matrix factorization whose user and item
 * factors are refreshed one latent dimension at a time, alternating a sweep
 * over all users with a sweep over all items in every training iteration.
 *
 * <p>Two kinds of weighting drive the updates:
 * <ul>
 *   <li>a per-item <em>confidence</em> (see {@link #confidences}), and</li>
 *   <li>a per-observed-entry <em>weight</em> (see {@link #weights}).</li>
 * </ul>
 * How both are derived is selected by the integer mode {@code rec.eals.wrmf.judge}.
 *
 * <p>NOTE(review): the class name suggests this is the element-wise ALS (eALS)
 * method; confirm against the algorithm reference before relying on that.
 */
@ModelData({"isRanking", "eals", "userFactors", "itemFactors", "trainMatrix"})
public class EALSRecommender extends MatrixFactorizationRecommender {
    /** Multiplier applied to an observed rating when building its weight ({@code rec.wrmf.weight.coefficient}, default 4.0). */
    protected float weightCoefficient;
    /** Exponent applied to each item's share of all ratings ({@code rec.eals.ratio}, default 0.4). */
    private float ratio;
    /** Scale factor for the popularity-based item confidences ({@code rec.eals.overall}, default 128). */
    private float overallWeight;
    /** Weighting mode ({@code rec.eals.wrmf.judge}, default 1); see {@link #initConfidencesAndWeights()}. */
    private int WRMFJudge;
    /** One confidence value per item, indexed by item index. */
    private double[] confidences;
    /** Weight of every entry present in the training matrix; created as a copy of the training matrix. */
    private SparseMatrix weights;

    /**
     * Runs the parent setup, reads this recommender's configuration values,
     * allocates the confidence and weight containers and fills them.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();

        weightCoefficient = conf.getFloat("rec.wrmf.weight.coefficient", Float.valueOf(4.0f)).floatValue();
        ratio = conf.getFloat("rec.eals.ratio", Float.valueOf(0.4f)).floatValue();
        overallWeight = conf.getFloat("rec.eals.overall", Float.valueOf(128.0f)).floatValue();
        WRMFJudge = conf.getInt("rec.eals.wrmf.judge", 1);

        confidences = new double[numItems];
        weights = new SparseMatrix(trainMatrix);
        initConfidencesAndWeights();
    }

    /**
     * Fills {@link #confidences} and {@link #weights} according to the mode.
     *
     * <ul>
     *   <li>Mode 0 or 2: an item's confidence is
     *       {@code overallWeight * p / sum(p)} where
     *       {@code p = (ratings of the item / numRates) ^ ratio}; any other
     *       mode gives every item confidence 1.</li>
     *   <li>Mode 1 or 2: an observed entry's weight is
     *       {@code 1 + weightCoefficient * rating}; any other mode gives every
     *       observed entry weight 1.</li>
     * </ul>
     */
    private void initConfidencesAndWeights() {
        boolean popularityConfidence = WRMFJudge == 0 || WRMFJudge == 2;
        if (popularityConfidence) {
            double popularityTotal = 0.0;
            for (int item = 0; item < numItems; ++item) {
                double share = (double) trainMatrix.columnSize(item) / (double) numRates;
                double popularity = Math.pow(share, ratio);
                confidences[item] = (double) overallWeight * popularity;
                popularityTotal += popularity;
            }
            // Normalise only after the total over all items is known.
            for (int item = 0; item < numItems; ++item) {
                confidences[item] /= popularityTotal;
            }
        } else {
            for (int item = 0; item < numItems; ++item) {
                confidences[item] = 1.0;
            }
        }

        boolean ratingDependentWeight = WRMFJudge == 1 || WRMFJudge == 2;
        for (MatrixEntry entry : trainMatrix) {
            int user = entry.row();
            int item = entry.column();
            double weight = ratingDependentWeight
                    ? 1.0 + (double) weightCoefficient * entry.get()
                    : 1.0;
            weights.set(user, item, weight);
        }
    }

    /**
     * Trains the factors for {@code numIterations} iterations. Each iteration
     * rebuilds the confidence-weighted item cache, sweeps all users, rebuilds
     * the user cache from the freshly updated user factors, and sweeps all items.
     *
     * @throws LibrecException propagated from the matrix operations, if they raise it
     */
    @Override
    protected void trainModel() throws LibrecException {
        List<List<Integer>> itemsOfUser = getUserItemsList(trainMatrix);
        List<List<Integer>> usersOfItem = getItemUsersList(trainMatrix);

        // Scratch buffers shared by all users/items; only the slots belonging to
        // the current user's items (or item's users) are written before being read.
        double[] usersPredictions = new double[numUsers];
        double[] itemsPredictions = new double[numItems];
        double[] usersWeights = new double[numUsers];
        double[] itemsWeights = new double[numItems];

        DenseMatrix itemFactorsCache = new DenseMatrix(numFactors, numFactors);

        for (int iter = 1; iter <= numIterations; ++iter) {
            fillConfidenceWeightedItemCache(itemFactorsCache);
            sweepUserFactors(itemsOfUser, itemFactorsCache, itemsPredictions, itemsWeights);

            DenseMatrix userFactorsCache = userFactors.transpose().mult(userFactors);
            sweepItemFactors(usersOfItem, userFactorsCache, usersPredictions, usersWeights);
        }
    }

    /**
     * Writes into {@code cache} the symmetric matrix whose entry {@code (a, b)}
     * is the sum over all items of
     * {@code confidence[item] * itemFactor[item][a] * itemFactor[item][b]}.
     * Only the lower triangle is computed; each value is mirrored.
     *
     * @param cache {@code numFactors x numFactors} matrix to overwrite
     * @throws LibrecException propagated from the matrix operations, if they raise it
     */
    private void fillConfidenceWeightedItemCache(DenseMatrix cache) throws LibrecException {
        for (int a = 0; a < numFactors; ++a) {
            for (int b = 0; b <= a; ++b) {
                double sum = 0.0;
                for (int item = 0; item < numItems; ++item) {
                    sum += confidences[item] * itemFactors.get(item, a) * itemFactors.get(item, b);
                }
                cache.set(a, b, sum);
                cache.set(b, a, sum);
            }
        }
    }

    /**
     * Updates every user's factor vector, one latent dimension at a time.
     *
     * @param itemsOfUser      for each user index, the item indices of that user's training entries
     * @param itemFactorsCache matrix produced by {@link #fillConfidenceWeightedItemCache(DenseMatrix)}
     * @param itemsPredictions scratch buffer indexed by item, holding the current prediction
     * @param itemsWeights     scratch buffer indexed by item, holding the entry weight
     * @throws LibrecException propagated from the matrix operations, if they raise it
     */
    private void sweepUserFactors(List<List<Integer>> itemsOfUser, DenseMatrix itemFactorsCache,
            double[] itemsPredictions, double[] itemsWeights) throws LibrecException {
        for (int user = 0; user < numUsers; ++user) {
            List<Integer> ratedItems = itemsOfUser.get(user);

            for (int item : ratedItems) {
                itemsPredictions[item] = DenseMatrix.rowMult(userFactors, user, itemFactors, item);
                itemsWeights[item] = weights.get(user, item);
            }

            for (int f = 0; f < numFactors; ++f) {
                double numer = 0.0;
                double denom = (double) regUser + itemFactorsCache.get(f, f);

                // Contribution of all other latent dimensions through the cache.
                for (int k = 0; k < numFactors; ++k) {
                    if (k != f) {
                        numer -= userFactors.get(user, k) * itemFactorsCache.get(f, k);
                    }
                }

                // Remove dimension f from each prediction, then accumulate the
                // contribution of the user's own training entries.
                for (int item : ratedItems) {
                    itemsPredictions[item] -= userFactors.get(user, f) * itemFactors.get(item, f);
                    numer += (itemsWeights[item] - (itemsWeights[item] - confidences[item]) * itemsPredictions[item]) * itemFactors.get(item, f);
                    denom += (itemsWeights[item] - confidences[item]) * itemFactors.get(item, f) * itemFactors.get(item, f);
                }

                userFactors.set(user, f, numer / denom);

                // Put dimension f back using the value just stored.
                for (int item : ratedItems) {
                    itemsPredictions[item] += userFactors.get(user, f) * itemFactors.get(item, f);
                }
            }
        }
    }

    /**
     * Updates every item's factor vector, one latent dimension at a time.
     *
     * @param usersOfItem      for each item index, the user indices of that item's training entries
     * @param userFactorsCache product of the transposed user factors with the user factors
     * @param usersPredictions scratch buffer indexed by user, holding the current prediction
     * @param usersWeights     scratch buffer indexed by user, holding the entry weight
     * @throws LibrecException propagated from the matrix operations, if they raise it
     */
    private void sweepItemFactors(List<List<Integer>> usersOfItem, DenseMatrix userFactorsCache,
            double[] usersPredictions, double[] usersWeights) throws LibrecException {
        for (int item = 0; item < numItems; ++item) {
            List<Integer> raters = usersOfItem.get(item);

            for (int user : raters) {
                usersPredictions[user] = DenseMatrix.rowMult(userFactors, user, itemFactors, item);
                usersWeights[user] = weights.get(user, item);
            }

            for (int f = 0; f < numFactors; ++f) {
                double numer = 0.0;
                double denom = confidences[item] * userFactorsCache.get(f, f) + (double) regItem;

                // Contribution of all other latent dimensions, scaled by this
                // item's confidence once the sum is complete.
                for (int k = 0; k < numFactors; ++k) {
                    if (k != f) {
                        numer -= itemFactors.get(item, k) * userFactorsCache.get(k, f);
                    }
                }
                numer *= confidences[item];

                // Remove dimension f from each prediction, then accumulate the
                // contribution of the item's own training entries.
                for (int user : raters) {
                    usersPredictions[user] -= userFactors.get(user, f) * itemFactors.get(item, f);
                    numer += (usersWeights[user] - (usersWeights[user] - confidences[item]) * usersPredictions[user]) * userFactors.get(user, f);
                    denom += (usersWeights[user] - confidences[item]) * userFactors.get(user, f) * userFactors.get(user, f);
                }

                itemFactors.set(item, f, numer / denom);

                // Put dimension f back using the value just stored.
                for (int user : raters) {
                    usersPredictions[user] += userFactors.get(user, f) * itemFactors.get(item, f);
                }
            }
        }
    }

    /**
     * Collects, for every user index in {@code [0, numUsers)}, the list
     * returned by {@code sparseMatrix.getColumns(user)}.
     *
     * @param sparseMatrix matrix to read from
     * @return list whose element at position {@code u} belongs to user {@code u}
     */
    private List<List<Integer>> getUserItemsList(SparseMatrix sparseMatrix) {
        List<List<Integer>> perUser = new ArrayList<List<Integer>>();
        for (int user = 0; user < numUsers; ++user) {
            List<Integer> columns = sparseMatrix.getColumns(user);
            perUser.add(columns);
        }
        return perUser;
    }

    /**
     * Collects, for every item index in {@code [0, numItems)}, the list
     * returned by {@code sparseMatrix.getRows(item)}.
     *
     * @param sparseMatrix matrix to read from
     * @return list whose element at position {@code i} belongs to item {@code i}
     */
    private List<List<Integer>> getItemUsersList(SparseMatrix sparseMatrix) {
        List<List<Integer>> perItem = new ArrayList<List<Integer>>();
        for (int item = 0; item < numItems; ++item) {
            List<Integer> rows = sparseMatrix.getRows(item);
            perItem.add(rows);
        }
        return perItem;
    }
}

