/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.BiMap;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.AbstractRecommender;
import net.librec.util.Lists;

@ModelData(value={"isRanking", "slim", "coefficientMatrix", "trainMatrix", "similarityMatrix", "knn"})
/**
 * User-based neighborhood SLIM recommender with an additional group-balance regularizer.
 *
 * <p>For every user a coefficient is learned for each neighbor user by cyclic per-coefficient
 * updates with soft-thresholding. Training uses L1 ({@code rec.slim.regularization.l1}) and L2
 * ({@code rec.slim.regularization.l2}) penalties plus a third penalty
 * ({@code rec.bnslim.regularization.l3}). The third penalty acts on the squared "balance": the
 * neighbor coefficients weighted by each neighbor's group membership. Membership is +1 when the
 * user's protected feature ({@code data.protected.feature}) has value 1.0, and -1 otherwise.
 *
 * <p>NOTE(review): training reads {@code trainMatrix} column {@code userIdx} and prediction reads
 * row {@code itemIdx}. That implies {@code trainMatrix} is laid out as [item][user] for this class.
 * Confirm against how the data model is configured.
 */
public class UBLNSLIMRecommender
extends AbstractRecommender {
    /** Maximum number of training sweeps ({@code rec.iterator.maximum}). */
    protected int numIterations;
    /** Learned weights; entry (neighbor, user) is the weight of {@code neighbor} for {@code user}. */
    private DenseMatrix coefficientMatrix;
    /** Per-user neighbor sets; when {@code knn <= 0} every entry is {@link #allUsers}. */
    private Set<Integer>[] userNNs;
    private float regL1Norm;
    private float regL2Norm;
    /** Weight of the group-balance penalty. */
    private float lambda3;
    /** +1 / -1 group membership per user index (see {@link #setup()}). */
    private int[] groupMembershipVector;
    protected SparseMatrix userFeatureMatrix;
    /** Name of the protected user feature ({@code data.protected.feature}). */
    private String protectedAttribute;
    BiMap<String, Integer> featureIdMapping;
    /** Scratch output of {@link #predictBothRatingAndBalance(int, int, int)}: last computed balance. */
    private double balance;
    /** Sum of per-coefficient average balances in the current iteration; only logged. */
    private double weights;
    // NOTE(review): static, so it is shared by every instance in the JVM; kept static for compatibility.
    protected static int knn;
    private SymmMatrix similarityMatrix;
    /** Candidate neighbors used when {@code knn <= 0}. */
    private Set<Integer> allUsers;
    /** Neighbors whose similarity is not strictly greater than this are skipped during training. */
    private float minSimThresh;

    /**
     * Reads hyper-parameters and initialises the coefficient matrix with a zero diagonal.
     * Also builds the neighbor sets and the +1/-1 group-membership vector.
     *
     * @throws LibrecException if the parent setup fails
     * @throws IllegalStateException if some user has features but the protected feature is not
     *     present in the user feature map
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        knn = this.conf.getInt("rec.neighbors.knn.number", 50);
        this.numIterations = this.conf.getInt("rec.iterator.maximum");
        this.regL1Norm = this.conf.getFloat("rec.slim.regularization.l1", Float.valueOf(1.0f)).floatValue();
        this.regL2Norm = this.conf.getFloat("rec.slim.regularization.l2", Float.valueOf(1.0f)).floatValue();
        this.lambda3 = this.conf.getFloat("rec.bnslim.regularization.l3", Float.valueOf(1.0f)).floatValue();
        this.minSimThresh = this.conf.getFloat("rec.bnslim.minsimilarity", Float.valueOf(-1.0f)).floatValue();
        this.protectedAttribute = this.conf.get("data.protected.feature");
        System.out.println("***");
        System.out.println("l1 reg: " + this.regL1Norm);
        System.out.println("l2 reg: " + this.regL2Norm);
        System.out.println("balance controller l3: " + this.lambda3);
        System.out.println("***");
        this.coefficientMatrix = new DenseMatrix(this.numUsers, this.numUsers);
        this.coefficientMatrix.init();
        this.similarityMatrix = this.context.getSimilarity().getSimilarityMatrix();
        // A user must never be its own neighbor.
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            this.coefficientMatrix.set(userIdx, userIdx, 0.0);
        }
        this.createUserNNs();
        this.userFeatureMatrix = this.getDataModel().getFeatureAppender().getUserFeatures();
        this.featureIdMapping = this.getDataModel().getFeatureAppender().getUserFeatureMap();
        this.groupMembershipVector = new int[this.numUsers];
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int userMembership = -1;
            // Users without any feature entries stay in group -1.
            if (this.userFeatureMatrix.getColumnsSet(userIdx).size() > 0) {
                Integer protectedFeatureIdx = this.featureIdMapping.get(this.protectedAttribute);
                if (protectedFeatureIdx == null) {
                    throw new IllegalStateException("Protected feature '" + this.protectedAttribute
                            + "' (data.protected.feature) is not present in the user feature map");
                }
                if (this.userFeatureMatrix.get(userIdx, protectedFeatureIdx) == 1.0) {
                    userMembership = 1;
                }
            }
            this.groupMembershipVector[userIdx] = userMembership;
        }
    }

    /**
     * Runs up to {@link #numIterations} sweeps over all (user, neighbor) coefficients.
     * Each coefficient is replaced by a soft-thresholded closed-form update. The loop stops early
     * when {@link #isConverged(int)} reports convergence and early stopping is enabled.
     *
     * @throws LibrecException declared for the framework contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            this.weights = 0.0;
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                Set<Integer> nearestNeighborCollection = knn > 0 ? this.userNNs[userIdx] : this.allUsers;
                // Dense copy of this user's ratings so target values can be read by item index.
                double[] itemRatingEntries = new double[this.numItems];
                Iterator<VectorEntry> itemItr = this.trainMatrix.colIterator(userIdx);
                while (itemItr.hasNext()) {
                    VectorEntry itemRatingEntry = itemItr.next();
                    itemRatingEntries[itemRatingEntry.index()] = itemRatingEntry.get();
                }
                int userMembership = this.groupMembershipVector[userIdx];
                for (Integer nearestNeighborUserIdx : nearestNeighborCollection) {
                    double sim = this.similarityMatrix.get(nearestNeighborUserIdx, userIdx);
                    // Written as !(sim > thresh) so that NaN similarities are skipped as well.
                    if (nearestNeighborUserIdx == userIdx || !(sim > (double) this.minSimThresh)) {
                        continue;
                    }
                    Iterator<VectorEntry> nnItemRatingItr = this.trainMatrix.colIterator(nearestNeighborUserIdx);
                    if (!nnItemRatingItr.hasNext()) {
                        // Neighbor has no ratings: nothing to fit, and it would divide by zero below.
                        continue;
                    }
                    double gradSum = 0.0;
                    double rateSum = 0.0;
                    double errors = 0.0;
                    double userBalanceSumSqr = 0.0;
                    double userBalanceSum = 0.0;
                    int nnCount = 0;
                    while (nnItemRatingItr.hasNext()) {
                        VectorEntry nnItemVectorEntry = nnItemRatingItr.next();
                        int nnItemIdx = nnItemVectorEntry.index();
                        double nnRating = nnItemVectorEntry.get();
                        double rating = itemRatingEntries[nnItemIdx];
                        // Residual with this neighbor left out; the call also refreshes this.balance.
                        double error = rating - this.predictBothRatingAndBalance(userIdx, nnItemIdx, nearestNeighborUserIdx);
                        double userBalance = this.balance;
                        userBalanceSumSqr += userBalance * userBalance;
                        userBalanceSum += userBalance;
                        gradSum += nnRating * error;
                        rateSum += nnRating * nnRating;
                        errors += error * error;
                        ++nnCount;
                    }
                    userBalanceSumSqr /= (double) nnCount;
                    userBalanceSum /= (double) nnCount;
                    gradSum /= (double) nnCount;
                    rateSum /= (double) nnCount;
                    errors /= (double) nnCount;
                    double coefficient = this.coefficientMatrix.get(nearestNeighborUserIdx, userIdx);
                    // The L1 penalty is on |w|; coefficients can be negative after a soft-threshold update.
                    this.loss += 0.5 * errors
                            + 0.5 * (double) this.regL2Norm * coefficient * coefficient
                            + (double) this.regL1Norm * Math.abs(coefficient)
                            + 0.5 * (double) this.lambda3 * userBalanceSumSqr;
                    this.weights += userBalanceSum;
                    // NOTE(review): this uses the target user's membership, while the balance itself is
                    // weighted by each neighbor's membership. Confirm against the BNSLIM derivation.
                    double beta = gradSum + (double) this.lambda3 * userMembership * userBalanceSum;
                    double denominator = (double) this.regL2Norm + rateSum + (double) this.lambda3;
                    double update = 0.0;
                    // Soft-thresholding: coefficients whose |beta| does not exceed the L1 weight become 0.
                    if ((double) this.regL1Norm < Math.abs(beta)) {
                        update = beta > 0.0
                                ? (beta - (double) this.regL1Norm) / denominator
                                : (beta + (double) this.regL1Norm) / denominator;
                    }
                    this.coefficientMatrix.set(nearestNeighborUserIdx, userIdx, update);
                }
            }
            // isConverged is evaluated first on purpose: it updates lastLoss and logs every iteration.
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
        }
    }

    /**
     * Predicts the rating of {@code itemIdx} for {@code userIdx} as the coefficient-weighted sum of
     * the ratings that the user's neighbors gave the item.
     *
     * <p>As a side effect, {@link #balance} is set to the sum of the same coefficients weighted by
     * each contributing neighbor's group membership.
     *
     * @param userIdx target user index
     * @param itemIdx item index
     * @param excludedUserIdx neighbor to leave out of both sums ({@code -1} to include all)
     * @return the predicted rating
     */
    protected double predictBothRatingAndBalance(int userIdx, int itemIdx, int excludedUserIdx) {
        double predictRating = 0.0;
        this.balance = 0.0;
        Set<Integer> neighbors = this.userNNs[userIdx];
        Iterator<VectorEntry> userEntryIterator = this.trainMatrix.rowIterator(itemIdx);
        while (userEntryIterator.hasNext()) {
            VectorEntry userEntry = userEntryIterator.next();
            int nearestNeighborUserIdx = userEntry.index();
            if (!neighbors.contains(nearestNeighborUserIdx) || nearestNeighborUserIdx == excludedUserIdx) {
                continue;
            }
            double coeff = this.coefficientMatrix.get(nearestNeighborUserIdx, userIdx);
            predictRating += userEntry.get() * coeff;
            this.balance += (double) this.groupMembershipVector[nearestNeighborUserIdx] * coeff;
        }
        return predictRating;
    }

    /**
     * Records the current loss and, when verbose, logs the loss, its change and the balance sum.
     *
     * @param iter 1-based iteration number
     * @return {@code true} from the second iteration on, when the loss decreased by less than 1e-5
     *     (or increased)
     */
    @Override
    protected boolean isConverged(int iter) {
        double deltaLoss = this.lastLoss - this.loss;
        this.lastLoss = this.loss;
        if (verbose) {
            String recName = this.getClass().getSimpleName();
            String info = recName + " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + deltaLoss;
            this.LOG.info(info);
            this.LOG.info("The item balance sum is " + this.weights + "\n");
        }
        return iter > 1 && deltaLoss < 1.0E-5;
    }

    /**
     * Predicts a rating, first building the neighbor sets if they are absent (e.g. for a model
     * restored without them).
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (null == this.userNNs || this.userNNs.length <= 0) {
            this.createUserNNs();
        }
        return this.predictBothRatingAndBalance(userIdx, itemIdx, -1);
    }

    /**
     * Builds {@link #userNNs} from the similarity matrix.
     *
     * <p>With {@code knn > 0}, each user keeps its {@code knn} most similar users when it has
     * more similarity entries than that, and all entries otherwise. With {@code knn <= 0}, every
     * user gets the shared {@link #allUsers} set. Previously the entries were left {@code null}
     * in that case, which made every prediction throw an NPE.
     */
    public void createUserNNs() {
        @SuppressWarnings("unchecked") // generic array creation; each slot only ever holds a Set<Integer>
        Set<Integer>[] neighborSets = new HashSet[this.numUsers];
        this.userNNs = neighborSets;
        if (knn > 0) {
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                SparseVector similarityVector = this.similarityMatrix.row(userIdx);
                if (knn < similarityVector.size()) {
                    List<Map.Entry<Integer, Double>> tempUserSimList = new ArrayList<Map.Entry<Integer, Double>>(similarityVector.size() + 1);
                    for (VectorEntry simVectorEntry : similarityVector) {
                        tempUserSimList.add(new AbstractMap.SimpleImmutableEntry<Integer, Double>(simVectorEntry.index(), simVectorEntry.get()));
                    }
                    // Keep the knn entries with the highest similarity.
                    tempUserSimList = Lists.sortListTopK(tempUserSimList, true, knn);
                    Set<Integer> topNeighbors = new HashSet<Integer>((int) ((double) tempUserSimList.size() / 0.5));
                    for (Map.Entry<Integer, Double> tempUserSimEntry : tempUserSimList) {
                        topNeighbors.add(tempUserSimEntry.getKey());
                    }
                    this.userNNs[userIdx] = topNeighbors;
                } else {
                    this.userNNs[userIdx] = similarityVector.getIndexSet();
                }
            }
        } else {
            // NOTE(review): built from trainMatrix.rows(); given the [item][user] access pattern
            // elsewhere in this class, confirm this yields user indices.
            this.allUsers = new HashSet<Integer>(this.trainMatrix.rows());
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                this.userNNs[userIdx] = this.allUsers;
            }
        }
    }
}

