/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.List;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.KernelSmoothing;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.recommender.cf.rating.LLORMAUpdater;

/**
 * Rating recommender that blends many locally weighted low-rank models (LLORMA-style).
 *
 * <p>Training has two phases:
 * <ol>
 *   <li>{@link #setup()} fits a global matrix factorization with plain SGD. Its user and item
 *       factors are only used to measure angular similarity between users and between items.</li>
 *   <li>{@link #trainModel()} repeatedly picks a random anchor (user, item) pair that has a
 *       training rating. It turns each user's/item's similarity to the anchor into a kernel
 *       weight and trains a local model with an {@link LLORMAUpdater}, which is started and later
 *       joined. At most {@code numThreads} local models run at once. Every finished model adds a
 *       kernel-weighted prediction to each test entry.</li>
 * </ol>
 * The final prediction of a test entry is the weight-normalised sum of local predictions. It falls
 * back to the global mean when that is NaN or 0, and is clamped to [minRate, maxRate].
 */
public class LLORMARecommender extends MatrixFactorizationRecommender {
    /**
     * Kernel type code handed to {@code KernelSmoothing.kernelize}. This is a decompiled literal,
     * presumably {@code KernelSmoothing.EPANECHNIKOV_KERNEL} — TODO confirm.
     */
    private static final int KERNEL_TYPE = 203;
    /** Kernel bandwidth used both when training local models and when blending their predictions. */
    private static final double KERNEL_WIDTH = 0.8;
    /** Scale factor mapping an angle in [0, pi] onto a similarity in [1, -1]. */
    private static final double TWO_OVER_PI = 2.0 / Math.PI;

    private int globalNumFactors;
    private int localNumFactors;
    private int globalNumIterations;
    private int localNumIterations;
    private int numThreads;
    protected double globalRegUser;
    protected double globalRegItem;
    protected double localRegUser;
    protected double localRegItem;
    private double globalLearnRate;
    private double localLearnRate;
    /** Blended predictions for the test entries; filled in by {@link #trainModel()}. */
    private SparseMatrix predictMatrix;
    private int numLocalModels;
    private DenseMatrix globalUserFactors;
    private DenseMatrix globalItemFactors;

    /**
     * Reads the global/local hyper-parameters, validates them, and trains the global model.
     *
     * <p>Configuration keys read here (defaults in parentheses):
     * {@code rec.global.factors.num} (20), {@code rec.global.iteration.maximum} (100),
     * {@code rec.global.user.regularization} (0.01), {@code rec.global.item.regularization} (0.01),
     * {@code rec.global.iteration.learnrate} (0.01), {@code rec.iteration.learnrate} (0.01),
     * {@code rec.thread.count} (4), {@code rec.model.num} (50). The local factor count, iteration
     * count and regularisation come from the inherited matrix-factorisation settings.
     *
     * @throws LibrecException if the parent setup fails
     * @throws IllegalArgumentException if {@code rec.model.num} is smaller than 1
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.globalNumFactors = this.conf.getInt("rec.global.factors.num", 20);
        this.localNumFactors = this.numFactors;
        this.globalNumIterations = this.conf.getInt("rec.global.iteration.maximum", 100);
        this.localNumIterations = this.numIterations;
        this.globalRegUser = this.conf.getDouble("rec.global.user.regularization", 0.01);
        this.globalRegItem = this.conf.getDouble("rec.global.item.regularization", 0.01);
        this.localRegUser = this.regUser;
        this.localRegItem = this.regItem;
        this.globalLearnRate = this.conf.getDouble("rec.global.iteration.learnrate", 0.01);
        this.localLearnRate = this.conf.getDouble("rec.iteration.learnrate", 0.01);
        this.numLocalModels = this.conf.getInt("rec.model.num", 50);
        if (this.numLocalModels < 1) {
            // With zero models trainModel() would never overwrite predictMatrix, which would then
            // still hold the test ratings themselves.
            throw new IllegalArgumentException("rec.model.num must be at least 1, but was " + this.numLocalModels);
        }
        // More threads than models can never be used; fewer than one would stall trainModel().
        this.numThreads = Math.max(1, Math.min(this.conf.getInt("rec.thread.count", 4), this.numLocalModels));
        this.globalUserFactors = new DenseMatrix(this.numUsers, this.globalNumFactors);
        this.globalItemFactors = new DenseMatrix(this.numItems, this.globalNumFactors);
        this.globalUserFactors.init(this.initMean, this.initStd);
        this.globalItemFactors.init(this.initMean, this.initStd);
        this.buildGlobalModel();
        // Placeholder with the test layout (a copy of the test ratings); trainModel() replaces it.
        this.predictMatrix = new SparseMatrix(this.testMatrix);
    }

    /**
     * Fits the global user/item factors by SGD over the training ratings.
     *
     * <p>Each rating updates both factor rows using the values read before the update.
     */
    private void buildGlobalModel() {
        for (int globalIter = 1; globalIter <= this.globalNumIterations; ++globalIter) {
            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double rating = matrixEntry.get();
                double predictRating = DenseMatrix.rowMult(this.globalUserFactors, userIdx, this.globalItemFactors, itemIdx);
                double error = rating - predictRating;
                for (int factorIdx = 0; factorIdx < this.globalNumFactors; ++factorIdx) {
                    double puf = this.globalUserFactors.get(userIdx, factorIdx);
                    double qif = this.globalItemFactors.get(itemIdx, factorIdx);
                    this.globalUserFactors.add(userIdx, factorIdx, this.globalLearnRate * (error * qif - this.globalRegUser * puf));
                    this.globalItemFactors.add(itemIdx, factorIdx, this.globalLearnRate * (error * puf - this.globalRegItem * qif));
                }
            }
        }
    }

    /**
     * Trains {@code numLocalModels} local models, at most {@code numThreads} at a time, and
     * blends their predictions for every test entry into {@link #predictMatrix}.
     *
     * <p>Learners live in a ring of {@code numThreads} slots. Model {@code k} runs in slot
     * {@code k % numThreads}, and models are joined in the order they were started, so the
     * completed-model counter also indexes the anchor arrays.
     *
     * @throws LibrecException declared by the framework contract
     * @throws IllegalStateException if no user has a training rating (no anchor can ever be
     *         chosen), or if the thread is interrupted while waiting for a learner; in the latter
     *         case the interrupt flag is restored
     */
    @Override
    protected void trainModel() throws LibrecException {
        if (!this.hasTrainingRatings()) {
            throw new IllegalStateException("LLORMA needs at least one training rating to choose anchor points");
        }
        LLORMAUpdater[] learners = new LLORMAUpdater[this.numThreads];
        int[] anchorArrayUser = new int[this.numLocalModels];
        int[] anchorArrayItem = new int[this.numLocalModels];
        int modelCount = 0;
        int completeModelCount = 0;
        int runningThreadCount = 0;
        int waitingThreadPointer = 0;
        int nextRunningSlot = 0;
        SparseMatrix cumPredictionMatrix = new SparseMatrix(this.testMatrix);
        SparseMatrix cumWeightMatrix = new SparseMatrix(this.testMatrix);
        for (MatrixEntry matrixEntry : this.testMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            cumPredictionMatrix.set(userIdx, itemIdx, 0.0);
            cumWeightMatrix.set(userIdx, itemIdx, 0.0);
        }
        while (completeModelCount < this.numLocalModels) {
            // A candidate anchor user is drawn on every pass, even one that ends up joining a
            // learner. This keeps the random-number sequence, and therefore the chosen anchors for
            // a fixed seed, unchanged.
            int anchorUser = Randoms.uniform(this.numUsers);
            List<Integer> itemList = this.trainMatrix.getColumns(anchorUser);
            if (itemList == null || itemList.isEmpty()) {
                continue;
            }
            if (runningThreadCount < this.numThreads && modelCount < this.numLocalModels) {
                int anchorItem = itemList.get(Randoms.uniform(itemList.size()));
                anchorArrayUser[modelCount] = anchorUser;
                anchorArrayItem[modelCount] = anchorItem;
                DenseVector userWeights = this.kernelSmoothing(this.numUsers, anchorUser, KERNEL_TYPE, KERNEL_WIDTH, false);
                DenseVector itemWeights = this.kernelSmoothing(this.numItems, anchorItem, KERNEL_TYPE, KERNEL_WIDTH, true);
                learners[nextRunningSlot] = new LLORMAUpdater(modelCount, this.localNumFactors, this.numUsers, this.numItems, anchorUser, anchorItem, this.localLearnRate, this.localRegUser, this.localRegItem, this.localNumIterations, userWeights, itemWeights, this.trainMatrix);
                learners[nextRunningSlot].start();
                ++modelCount;
                ++runningThreadCount;
                ++nextRunningSlot;
                continue;
            }
            if (runningThreadCount <= 0) {
                continue;
            }
            try {
                learners[waitingThreadPointer].join();
            } catch (InterruptedException ie) {
                // Continuing would read factors from a learner that may still be training.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Join failed: " + ie, ie);
            }
            int anchorIdx = completeModelCount++;
            this.accumulateLocalModel(learners[waitingThreadPointer], anchorArrayUser[anchorIdx], anchorArrayItem[anchorIdx], cumPredictionMatrix, cumWeightMatrix);
            nextRunningSlot = waitingThreadPointer;
            waitingThreadPointer = (waitingThreadPointer + 1) % this.numThreads;
            --runningThreadCount;
        }
        this.predictMatrix = this.blendPredictions(cumPredictionMatrix, cumWeightMatrix);
    }

    /** Returns whether at least one user has a non-empty training item list. */
    private boolean hasTrainingRatings() {
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            List<Integer> itemList = this.trainMatrix.getColumns(userIdx);
            if (itemList != null && !itemList.isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds one finished local model's weighted prediction and its kernel weight to the running
     * sums of every test entry.
     */
    private void accumulateLocalModel(LLORMAUpdater learner, int anchorUser, int anchorItem, SparseMatrix cumPredictionMatrix, SparseMatrix cumWeightMatrix) {
        for (MatrixEntry matrixEntry : this.testMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            double weight = KernelSmoothing.kernelize(this.getUserSimilarity(anchorUser, userIdx), KERNEL_WIDTH, KERNEL_TYPE) * KernelSmoothing.kernelize(this.getItemSimilarity(anchorItem, itemIdx), KERNEL_WIDTH, KERNEL_TYPE);
            double newPrediction = learner.getLocalUserFactors().row(userIdx, false).inner(learner.getLocalItemFactors().row(itemIdx, false)) * weight;
            cumWeightMatrix.set(userIdx, itemIdx, cumWeightMatrix.get(userIdx, itemIdx) + weight);
            cumPredictionMatrix.set(userIdx, itemIdx, cumPredictionMatrix.get(userIdx, itemIdx) + newPrediction);
        }
    }

    /**
     * Builds the final prediction matrix from the accumulated sums.
     *
     * <p>For each test entry this is prediction / weight. It falls back to the global mean when
     * that ratio is NaN (zero total weight) or exactly 0, and is then clamped to
     * [minRate, maxRate].
     */
    private SparseMatrix blendPredictions(SparseMatrix cumPredictionMatrix, SparseMatrix cumWeightMatrix) {
        SparseMatrix blended = new SparseMatrix(this.testMatrix);
        for (MatrixEntry matrixEntry : this.testMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            double prediction = cumPredictionMatrix.get(userIdx, itemIdx) / cumWeightMatrix.get(userIdx, itemIdx);
            if (Double.isNaN(prediction) || prediction == 0.0) {
                prediction = this.globalMean;
            }
            if (prediction < this.minRate) {
                prediction = this.minRate;
            }
            if (prediction > this.maxRate) {
                prediction = this.maxRate;
            }
            blended.set(userIdx, itemIdx, prediction);
        }
        return blended;
    }

    /** Angular similarity of two users' global factor rows; see {@link #angularSimilarity}. */
    private double getUserSimilarity(int userIdx1, int userIdx2) {
        return angularSimilarity(this.globalUserFactors.row(userIdx1, false), this.globalUserFactors.row(userIdx2, false));
    }

    /** Angular similarity of two items' global factor rows; see {@link #angularSimilarity}. */
    private double getItemSimilarity(int itemIdx1, int itemIdx2) {
        return angularSimilarity(this.globalItemFactors.row(itemIdx1, false), this.globalItemFactors.row(itemIdx2, false));
    }

    /**
     * Computes {@code 1 - (2/pi) * acos(cos(v1, v2))}, a similarity in [-1, 1].
     *
     * <p>The cosine is clamped to [-1, 1] before {@code acos}. Floating-point rounding can push it
     * just past 1 for (nearly) parallel vectors, such as a vector compared with itself. Unclamped,
     * {@code acos} would return NaN there and the similarity would wrongly drop to 0. A NaN cosine
     * (a zero-length vector) still yields 0.
     */
    private static double angularSimilarity(DenseVector v1, DenseVector v2) {
        double cosine = v1.inner(v2) / (Math.sqrt(v1.inner(v1)) * Math.sqrt(v2.inner(v2)));
        if (Double.isNaN(cosine)) {
            return 0.0;
        }
        double clamped = Math.max(-1.0, Math.min(1.0, cosine));
        return 1.0 - TWO_OVER_PI * Math.acos(clamped);
    }

    /**
     * Builds a weight vector of length {@code size}. Entry {@code index} is the kernelised
     * similarity between {@code index} and the anchor, measured on the global item factors when
     * {@code isItemFeature} is true and on the global user factors otherwise.
     */
    private DenseVector kernelSmoothing(int size, int anchorIdx, int kernelType, double width, boolean isItemFeature) {
        DenseVector newFeatureVector = new DenseVector(size);
        for (int index = 0; index < size; ++index) {
            double sim = isItemFeature ? this.getItemSimilarity(index, anchorIdx) : this.getUserSimilarity(index, anchorIdx);
            newFeatureVector.set(index, KernelSmoothing.kernelize(sim, width, kernelType));
        }
        return newFeatureVector;
    }

    /**
     * Returns the blended prediction stored for (userIdx, itemIdx).
     *
     * <p>Only test entries are filled in; other positions return whatever
     * {@code SparseMatrix.get} yields for an absent entry (presumably 0 — confirm).
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.predictMatrix.get(userIdx, itemIdx);
    }
}

