/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.BiMap;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.AbstractRecommender;
import net.librec.util.Lists;

/**
 * Item-based SLIM-style top-N recommender with an additional group-balance regularizer.
 *
 * <p>Each coefficient {@code W[j][i]} (weight of neighbour item {@code j} when scoring target
 * item {@code i}) is fitted by coordinate-wise soft-thresholding updates. Besides the SLIM L1
 * ({@code rec.slim.regularization.l1}) and L2 ({@code rec.slim.regularization.l2}) penalties, a
 * third term weighted by {@code rec.bnslim.regularization.l3} penalises the squared,
 * group-membership-weighted sum of coefficients contributed by each user's rated neighbours.
 * An item's membership is {@code +1} if its protected feature (named by
 * {@code data.protected.feature}) equals {@code 1.0}, otherwise {@code -1}.
 *
 * <p>NOTE(review): {@code trainMatrix} is read with {@code rowIterator(itemIdx)} and
 * {@code colIterator(userIdx)}, i.e. this class treats it as item x user — confirm against the
 * data pipeline that prepares it.
 */
@ModelData(value={"isRanking", "slim", "coefficientMatrix", "trainMatrix", "similarityMatrix", "knn"})
public class BLNSLIMRecommender
extends AbstractRecommender {
    /** Maximum number of training sweeps ({@code rec.iterator.maximum}). */
    protected int numIterations;
    /** Item-item coefficients; entry (j, i) weights neighbour item j when scoring item i. */
    private DenseMatrix coefficientMatrix;
    /** Per-item neighbour sets; only populated when {@code knn > 0}. */
    private Set<Integer>[] itemNNs;
    /** L1 (lasso) regularization weight. */
    private float regL1Norm;
    /** L2 (ridge) regularization weight. */
    private float regL2Norm;
    /** Weight of the group-balance regularizer. */
    private float lambda3;
    /** Per-item group membership: +1 for items with the protected feature set, -1 otherwise. */
    private int[] groupMembershipVector;
    protected SparseMatrix itemFeatureMatrix;
    /** Name of the protected item feature ({@code data.protected.feature}). */
    private String protectedAttribute;
    BiMap<String, Integer> featureIdMapping;
    /** Diagnostic only: sum of averaged balance terms over the current sweep (logged when verbose). */
    private double weights;
    /**
     * Neighbourhood size ({@code rec.neighbors.knn.number}); values {@code <= 0} select the
     * {@code allItems} candidate set instead of per-item kNN sets.
     * NOTE(review): this field is static, so it is shared by every instance in the JVM.
     */
    protected static int knn;
    private SymmMatrix similarityMatrix;
    /** Candidate neighbour set used when {@code knn <= 0}; built from {@code trainMatrix.columns()}. */
    private Set<Integer> allItems;

    /**
     * Reads hyper-parameters, initialises the coefficient matrix with a zero diagonal, builds the
     * item neighbourhoods and assigns every item to the protected (+1) or other (-1) group.
     *
     * @throws LibrecException if the superclass setup fails
     * @throws IllegalStateException if an item has features but the protected feature is not in
     *     the item feature map
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        knn = this.conf.getInt("rec.neighbors.knn.number", 50);
        this.numIterations = this.conf.getInt("rec.iterator.maximum");
        this.regL1Norm = this.conf.getFloat("rec.slim.regularization.l1", 1.0f);
        this.regL2Norm = this.conf.getFloat("rec.slim.regularization.l2", 1.0f);
        this.lambda3 = this.conf.getFloat("rec.bnslim.regularization.l3", 1.0f);
        this.protectedAttribute = this.conf.get("data.protected.feature");
        this.LOG.info("l1 reg: " + this.regL1Norm + ", l2 reg: " + this.regL2Norm
                + ", balance controller l3: " + this.lambda3);

        this.coefficientMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.coefficientMatrix.init();
        this.similarityMatrix = this.context.getSimilarity().getSimilarityMatrix();
        // An item never contributes to its own score; trainModel skips the diagonal, so it stays 0.
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            this.coefficientMatrix.set(itemIdx, itemIdx, 0.0);
        }
        this.createItemNNs();

        this.itemFeatureMatrix = this.getDataModel().getFeatureAppender().getItemFeatures();
        this.featureIdMapping = this.getDataModel().getFeatureAppender().getItemFeatureMap();
        // Resolve the protected feature column once instead of once per item.
        Integer protectedFeatureIdx = this.featureIdMapping.get(this.protectedAttribute);
        this.groupMembershipVector = new int[this.numItems];
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            int itemMembership = -1;
            if (!this.itemFeatureMatrix.getColumnsSet(itemIdx).isEmpty()) {
                if (protectedFeatureIdx == null) {
                    throw new IllegalStateException("Protected item feature '" + this.protectedAttribute
                            + "' (data.protected.feature) is not present in the item feature map");
                }
                if (this.itemFeatureMatrix.get(itemIdx, protectedFeatureIdx) == 1.0) {
                    itemMembership = 1;
                }
            }
            this.groupMembershipVector[itemIdx] = itemMembership;
        }
    }

    /**
     * Runs up to {@code numIterations} coordinate-descent sweeps over every target item, stopping
     * early when {@link #isConverged(int)} reports convergence and {@code earlyStop} is enabled.
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            this.weights = 0.0;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                this.updateItemCoefficients(itemIdx);
            }
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
        }
    }

    /**
     * Updates every coefficient W[j][itemIdx] for the neighbours j of one target item and adds
     * their loss contributions to {@code loss} (and balance sums to {@code weights}).
     */
    private void updateItemCoefficients(int itemIdx) {
        // Dense copy of the target item's ratings, indexed by user, for O(1) lookups below.
        double[] userRatingEntries = new double[this.numUsers];
        Iterator<VectorEntry> userItr = this.trainMatrix.rowIterator(itemIdx);
        while (userItr.hasNext()) {
            VectorEntry userRatingEntry = userItr.next();
            userRatingEntries[userRatingEntry.index()] = userRatingEntry.get();
        }
        // lambda3 signed by the target item's group; identical for all of its neighbours.
        double signedLambda3 = (double) this.lambda3 * this.groupMembershipVector[itemIdx];

        for (int nnItemIdx : this.neighborsOf(itemIdx)) {
            if (nnItemIdx == itemIdx) {
                continue; // keep the diagonal at zero
            }
            Iterator<VectorEntry> nnUserRatingItr = this.trainMatrix.rowIterator(nnItemIdx);
            if (!nnUserRatingItr.hasNext()) {
                continue; // neighbour without ratings: nothing to fit and nnCount would be 0
            }
            double gradSum = 0.0;
            double rateSum = 0.0;
            double errors = 0.0;
            double itemBalanceSumSqr = 0.0;
            double itemBalanceSum = 0.0;
            int nnCount = 0;
            while (nnUserRatingItr.hasNext()) {
                VectorEntry nnUserVectorEntry = nnUserRatingItr.next();
                int nnUserIdx = nnUserVectorEntry.index();
                double nnRating = nnUserVectorEntry.get();
                double rating = userRatingEntries[nnUserIdx];
                // Residual of the target rating with this neighbour left out of the prediction.
                double error = rating - this.predict(nnUserIdx, itemIdx, nnItemIdx);
                double itemBalance = this.balancePredictor(nnUserIdx, itemIdx, nnItemIdx);
                itemBalanceSumSqr += itemBalance * itemBalance;
                itemBalanceSum += itemBalance;
                gradSum += nnRating * error;
                rateSum += nnRating * nnRating;
                errors += error * error;
                ++nnCount;
            }
            itemBalanceSumSqr /= nnCount;
            itemBalanceSum /= nnCount;
            gradSum /= nnCount;
            rateSum /= nnCount;
            errors /= nnCount;

            double coefficient = this.coefficientMatrix.get(nnItemIdx, itemIdx);
            // The L1 penalty is on |w|; coefficients can be negative after a soft-threshold step.
            this.loss += 0.5 * errors
                    + 0.5 * (double) this.regL2Norm * coefficient * coefficient
                    + (double) this.regL1Norm * Math.abs(coefficient)
                    + 0.5 * (double) this.lambda3 * itemBalanceSumSqr;
            this.weights += itemBalanceSum;

            // Soft-thresholding coordinate update.
            double beta = gradSum + signedLambda3 * itemBalanceSum;
            double denominator = (double) this.regL2Norm + rateSum + (double) this.lambda3;
            double update = 0.0;
            if ((double) this.regL1Norm < Math.abs(beta)) {
                update = beta > 0.0
                        ? (beta - (double) this.regL1Norm) / denominator
                        : (beta + (double) this.regL1Norm) / denominator;
            }
            this.coefficientMatrix.set(nnItemIdx, itemIdx, update);
        }
    }

    /**
     * Returns the candidate neighbour set of an item: its kNN set when {@code knn > 0}, otherwise
     * the shared {@code allItems} set built by {@link #createItemNNs()}.
     */
    private Set<Integer> neighborsOf(int itemIdx) {
        return knn > 0 ? this.itemNNs[itemIdx] : this.allItems;
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} as the sum, over the user's rated items j that
     * are neighbours of {@code itemIdx} (excluding {@code excludedItemIdx}), of
     * {@code rating(j) * W[j][itemIdx]}.
     *
     * @param excludedItemIdx neighbour to leave out, or {@code -1} to keep all neighbours
     */
    protected double predict(int userIdx, int itemIdx, int excludedItemIdx) {
        Set<Integer> neighbors = this.neighborsOf(itemIdx);
        double predictRating = 0.0;
        Iterator<VectorEntry> itemEntryIterator = this.trainMatrix.colIterator(userIdx);
        while (itemEntryIterator.hasNext()) {
            VectorEntry itemEntry = itemEntryIterator.next();
            int nnItemIdx = itemEntry.index();
            if (nnItemIdx == excludedItemIdx || !neighbors.contains(nnItemIdx)) {
                continue;
            }
            predictRating += itemEntry.get() * this.coefficientMatrix.get(nnItemIdx, itemIdx);
        }
        return predictRating;
    }

    /**
     * Sums {@code membership(j) * W[j][itemIdx]} over the user's rated items j that are
     * neighbours of {@code itemIdx}, excluding {@code excludedItemIdx}.
     */
    protected double balancePredictor(int userIdx, int itemIdx, int excludedItemIdx) {
        Set<Integer> neighbors = this.neighborsOf(itemIdx);
        double predictBalance = 0.0;
        Iterator<VectorEntry> itemEntryIterator = this.trainMatrix.colIterator(userIdx);
        while (itemEntryIterator.hasNext()) {
            int nnItemIdx = itemEntryIterator.next().index();
            if (nnItemIdx == excludedItemIdx || !neighbors.contains(nnItemIdx)) {
                continue;
            }
            predictBalance += (double) this.groupMembershipVector[nnItemIdx]
                    * this.coefficientMatrix.get(nnItemIdx, itemIdx);
        }
        return predictBalance;
    }

    /**
     * Records the current loss and reports convergence once the loss decreased by less than
     * {@code 1e-5} since the previous sweep. A loss increase (negative delta) also counts as
     * converged. The first sweep never converges.
     */
    @Override
    protected boolean isConverged(int iter) {
        double deltaLoss = this.lastLoss - this.loss;
        this.lastLoss = this.loss;
        if (verbose) {
            String recName = this.getClass().getSimpleName();
            this.LOG.info(recName + " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + deltaLoss);
            this.LOG.info("The item balance sum is " + this.weights + "\n");
        }
        return iter > 1 && deltaLoss < 1.0E-5;
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} using all neighbours. Neighbourhoods are not
     * listed in {@code @ModelData}, so they are rebuilt lazily here when missing (presumably
     * after the model was restored from saved data).
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (this.itemNNs == null || this.itemNNs.length <= 0 || (knn <= 0 && this.allItems == null)) {
            this.createItemNNs();
        }
        return this.predict(userIdx, itemIdx, -1);
    }

    /**
     * Builds neighbour candidates. When {@code knn > 0}, each item gets the indices of its top-knn
     * entries in its similarity row (or the whole row if it has at most knn entries). Otherwise
     * only {@code allItems} is built and every {@code itemNNs} slot is left null.
     */
    @SuppressWarnings("unchecked")
    public void createItemNNs() {
        this.itemNNs = new HashSet[this.numItems];
        if (knn <= 0) {
            this.allItems = new HashSet<Integer>(this.trainMatrix.columns());
            return;
        }
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            SparseVector similarityVector = this.similarityMatrix.row(itemIdx);
            if (knn >= similarityVector.size()) {
                this.itemNNs[itemIdx] = similarityVector.getIndexSet();
                continue;
            }
            List<Map.Entry<Integer, Double>> tempItemSimList =
                    new ArrayList<Map.Entry<Integer, Double>>(similarityVector.size() + 1);
            for (VectorEntry simVectorEntry : similarityVector) {
                tempItemSimList.add(new AbstractMap.SimpleImmutableEntry<Integer, Double>(
                        simVectorEntry.index(), simVectorEntry.get()));
            }
            tempItemSimList = Lists.sortListTopK(tempItemSimList, true, knn);
            Set<Integer> neighbours = new HashSet<Integer>(tempItemSimList.size() * 2);
            for (Map.Entry<Integer, Double> tempItemSimEntry : tempItemSimList) {
                neighbours.add(tempItemSimEntry.getKey());
            }
            this.itemNNs[itemIdx] = neighbours;
        }
    }
}

