/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.cf.rating.BiasedMFRecommender;

/**
 * Rating recommender that extends {@link BiasedMFRecommender} with parameters that depend on
 * the day a rating was given.
 *
 * <p>A prediction for user {@code u}, item {@code i} rated on day {@code t} is assembled from:
 * <ul>
 *   <li>the global mean,</li>
 *   <li>an item bias (static part plus a per-time-bin part) scaled by a user factor
 *       (static part plus a per-day part),</li>
 *   <li>a user bias (static part, a drift term {@code Alpha[u] * dev(u, t)}, and a per-day part),</li>
 *   <li>the sum of {@code Y[j] . Q[i]} over the items {@code j} rated by the user, normalised by
 *       {@code |R(u)|^-0.5},</li>
 *   <li>the dot product of {@code Q[i]} with a day-dependent user factor vector.</li>
 * </ul>
 *
 * <p>NOTE(review): several fields below are {@code static} but are (re)assigned in
 * {@link #setup()}, so they are shared by every instance in the JVM. They are left static here
 * to keep the class layout unchanged; confirm whether that sharing is intended.
 */
@ModelData(value={"isRating", "timesvd", "userFactors", "itemFactors", "userBiases", "itemBiases", "trainMatrix", "timeMatrix"})
public class TimeSVDRecommender
extends BiasedMFRecommender {
    /** Number of day slots: day distance between the largest and smallest timestamp, plus one. */
    private static int numDays;
    /** Per user: mean day index (relative to {@link #minTimestamp}) of the user's rated items. */
    private DenseVector userMeanDate;
    /** Exponent applied in {@link #dev(int, int)}; read from {@code rec.learnrate.decay}. */
    private float beta;
    /** Number of time bins used for the per-bin item bias; read from {@code rec.numBins}. */
    private int numBins;
    /** Per-item factors ({@code numItems x numFactors}) summed over the items a user rated. */
    private DenseMatrix Y;
    /** Per-item, per-bin bias ({@code numItems x numBins}). */
    private DenseMatrix Bit;
    /** Per-(user, day) bias; entries are created lazily with a random value during training. */
    private Table<Integer, Integer, Double> But;
    /** Per-user coefficient of {@code dev(u, t)} in the user bias. */
    private DenseVector Alpha;
    /** Per-user, per-factor coefficient of {@code dev(u, t)} ({@code numUsers x numFactors}). */
    private DenseMatrix Auk;
    /** user -> (factor, day) -> value; entries are created lazily during training. */
    private Map<Integer, Table<Integer, Integer, Double>> Pukt;
    /** Per-user scaling of the item bias. */
    private DenseVector Cu;
    /** Per-user, per-day scaling of the item bias ({@code numUsers x numDays}). */
    private DenseMatrix Cut;
    /** Smallest timestamp found in {@link #timeMatrix}. */
    private static long minTimestamp;
    /** Largest timestamp found in {@link #timeMatrix}. */
    private static long maxTimestamp;
    /** Guava cache specification string; read from {@code guava.cache.spec}. */
    protected static String cacheSpec;
    /** Cache mapping a user index to the column indices stored in that user's training row. */
    private LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Timestamp of each (user, item) entry; passed to {@link TimeUnit#MILLISECONDS} conversions. */
    private static SparseMatrix timeMatrix;
    /** Item factors ({@code numItems x numFactors}). */
    protected DenseMatrix Q;
    /** Static user factors ({@code numUsers x numFactors}). */
    protected DenseMatrix P;

    /**
     * Reads the configuration, allocates and initialises all model parameters, and computes
     * the mean rating day of every user.
     *
     * @throws LibrecException if the parent setup fails or a user's rated items cannot be loaded
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        beta = this.conf.getFloat("rec.learnrate.decay", Float.valueOf(0.015f)).floatValue();
        numBins = this.conf.getInt("rec.numBins", 6);
        timeMatrix = (SparseMatrix)this.getDataModel().getDatetimeDataSet();
        getMaxAndMinTimeStamp();
        numDays = days(maxTimestamp, minTimestamp) + 1;

        userBiases = new DenseVector(numUsers);
        userBiases.init();
        itemBiases = new DenseVector(numItems);
        itemBiases.init();
        Alpha = new DenseVector(numUsers);
        Alpha.init();
        Bit = new DenseMatrix(numItems, numBins);
        Bit.init();
        Y = new DenseMatrix(numItems, numFactors);
        Y.init();
        Auk = new DenseMatrix(numUsers, numFactors);
        Auk.init();
        But = HashBasedTable.create();
        Pukt = new HashMap<Integer, Table<Integer, Integer, Double>>();
        Cu = new DenseVector(numUsers);
        Cu.init();
        Cut = new DenseMatrix(numUsers, numDays);
        Cut.init();

        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        userItemsCache = trainMatrix.rowColumnsCache(cacheSpec);

        P = new DenseMatrix(numUsers, numFactors);
        Q = new DenseMatrix(numItems, numFactors);
        P.init();
        Q.init();

        // Mean rating day over all training entries with a positive rating.
        double sum = 0.0;
        int cnt = 0;
        for (MatrixEntry me : trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rui = me.get();
            if (rui <= 0.0) {
                continue;
            }
            sum += days((long)timeMatrix.get(u, i), minTimestamp);
            ++cnt;
        }
        // Guard against 0/0 (NaN) when no positive rating exists.
        double globalMeanDate = cnt > 0 ? sum / cnt : 0.0;

        // Per-user mean rating day; users without rated items fall back to the global mean.
        userMeanDate = new DenseVector(numUsers);
        for (int u = 0; u < numUsers; ++u) {
            List<Integer> Ru = ratedItems(u);
            sum = 0.0;
            for (int i : Ru) {
                sum += days((long)timeMatrix.get(u, i), minTimestamp);
            }
            double mean = Ru.size() > 0 ? sum / Ru.size() : globalMeanDate;
            userMeanDate.set(u, mean);
        }
    }

    /**
     * Trains all parameters with per-entry gradient steps over the training matrix, for at
     * most {@code numIterations} passes or until {@code isConverged} reports convergence.
     *
     * @throws LibrecException if a user's rated items cannot be loaded or the parent's
     *                         convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; ++iter) {
            loss = 0.0;
            for (MatrixEntry me : trainMatrix) {
                int u = me.row();
                int i = me.column();
                double rui = me.get();

                long timestamp = (long)timeMatrix.get(u, i);
                int t = days(timestamp, minTimestamp);
                int binIdx = bin(t);
                double dev_ut = dev(u, t);

                double bi = itemBiases.get(i);
                double bit = Bit.get(i, binIdx);
                double bu = userBiases.get(u);
                double cu = Cu.get(u);
                double cut = Cut.get(u, t);

                // Lazily create the per-(user, day) bias with a random start value.
                if (!But.contains(u, t)) {
                    But.put(u, t, Randoms.random());
                }
                double but = But.get(u, t);
                double au = Alpha.get(u);

                // --- prediction with the current parameters ---
                double pui = globalMean + (bi + bit) * (cu + cut);
                pui += bu + au * dev_ut + but;

                List<Integer> Ru = ratedItems(u);
                double sum_y = 0.0;
                for (int j : Ru) {
                    sum_y += DenseMatrix.rowMult(Y, j, Q, i);
                }
                double wi = Ru.size() > 0 ? Math.pow(Ru.size(), -0.5) : 0.0;
                pui += sum_y * wi;

                // Lazily create the user's (factor, day) table.
                Table<Integer, Integer, Double> Pkt = Pukt.get(u);
                if (Pkt == null) {
                    Pkt = HashBasedTable.create();
                    Pukt.put(u, Pkt);
                }
                for (int k = 0; k < numFactors; ++k) {
                    double qik = Q.get(i, k);
                    if (!Pkt.contains(k, t)) {
                        Pkt.put(k, t, Randoms.random());
                    }
                    double puk = P.get(u, k) + Auk.get(u, k) * dev_ut + Pkt.get(k, t);
                    pui += puk * qik;
                }

                double eui = pui - rui;
                loss += eui * eui;

                // --- bias updates (each uses the values read before any update) ---
                double sgd = eui * (cu + cut) + regBias * bi;
                itemBiases.add(i, -learnRate * sgd);
                loss += regBias * bi * bi;

                sgd = eui * (cu + cut) + regBias * bit;
                Bit.add(i, binIdx, -learnRate * sgd);
                loss += regBias * bit * bit;

                sgd = eui * (bi + bit) + regBias * cu;
                Cu.add(u, -learnRate * sgd);
                loss += regBias * cu * cu;

                sgd = eui * (bi + bit) + regBias * cut;
                Cut.add(u, t, -learnRate * sgd);
                loss += regBias * cut * cut;

                sgd = eui + regBias * bu;
                userBiases.add(u, -learnRate * sgd);
                loss += regBias * bu * bu;

                sgd = eui * dev_ut + regBias * au;
                Alpha.add(u, -learnRate * sgd);
                loss += regBias * au * au;

                sgd = eui + regBias * but;
                double delta = but - learnRate * sgd;
                But.put(u, t, delta);
                loss += regBias * but * but;

                // --- factor updates ---
                for (int k = 0; k < numFactors; ++k) {
                    double qik = Q.get(i, k);
                    double puk = P.get(u, k);
                    double auk = Auk.get(u, k);
                    double pkt = Pkt.get(k, t);
                    double pukt = puk + auk * dev_ut + pkt;

                    // Column k of Y is only modified at the end of this iteration of k,
                    // so this sum sees the values from before this entry's update.
                    double sum_yk = 0.0;
                    for (int j : Ru) {
                        sum_yk += Y.get(j, k);
                    }

                    sgd = eui * (pukt + wi * sum_yk) + regItem * qik;
                    Q.add(i, k, -learnRate * sgd);
                    loss += regItem * qik * qik;

                    sgd = eui * qik + regUser * puk;
                    P.add(u, k, -learnRate * sgd);
                    loss += regUser * puk * puk;

                    sgd = eui * qik * dev_ut + regUser * auk;
                    Auk.add(u, k, -learnRate * sgd);
                    loss += regUser * auk * auk;

                    sgd = eui * qik + regUser * pkt;
                    delta = pkt - learnRate * sgd;
                    Pkt.put(k, t, delta);
                    loss += regUser * pkt * pkt;

                    for (int j : Ru) {
                        double yjk = Y.get(j, k);
                        sgd = eui * wi * qik + regItem * yjk;
                        Y.add(j, k, -learnRate * sgd);
                        loss += regItem * yjk * yjk;
                    }
                }
            }
            loss *= 0.5;
            if (isConverged(iter)) {
                break;
            }
            updateLRate(iter);
        }
    }

    /**
     * Predicts the rating of a (user, item) pair using the timestamp stored for that pair in
     * the time matrix. Per-day parameters that were never created during training count as 0.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the predicted rating
     * @throws LibrecException if the user's rated items cannot be loaded
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        long timestamp = (long)timeMatrix.get(userIdx, itemIdx);
        int t = days(timestamp, minTimestamp);
        int binIdx = bin(t);
        double dev_ut = dev(userIdx, t);

        double pred = globalMean;
        pred += (itemBiases.get(itemIdx) + Bit.get(itemIdx, binIdx)) * (Cu.get(userIdx) + Cut.get(userIdx, t));

        double but = But.contains(userIdx, t) ? But.get(userIdx, t) : 0.0;
        pred += userBiases.get(userIdx) + Alpha.get(userIdx) * dev_ut + but;

        List<Integer> Ru = ratedItems(userIdx);
        double sum_y = 0.0;
        for (int j : Ru) {
            sum_y += DenseMatrix.rowMult(Y, j, Q, itemIdx);
        }
        double wi = Ru.size() > 0 ? Math.pow(Ru.size(), -0.5) : 0.0;
        pred += sum_y * wi;

        // The user's table does not depend on k, so look it up once.
        Table<Integer, Integer, Double> pkt = Pukt.get(userIdx);
        for (int k = 0; k < numFactors; ++k) {
            double qik = Q.get(itemIdx, k);
            double puk = P.get(userIdx, k) + Auk.get(userIdx, k) * dev_ut;
            if (pkt != null) {
                puk += pkt.contains(k, t) ? pkt.get(k, t) : 0.0;
            }
            pred += puk * qik;
        }
        return pred;
    }

    /**
     * Returns the column indices of the given user's training row via the cache.
     *
     * @param userIdx user index
     * @return the cached list of item indices for the user
     * @throws LibrecException wrapping the cache loader failure, so that callers never continue
     *                         with a missing or stale list
     */
    private List<Integer> ratedItems(int userIdx) throws LibrecException {
        try {
            return userItemsCache.get(userIdx);
        }
        catch (ExecutionException e) {
            throw new LibrecException("Failed to load rated items of user " + userIdx, e);
        }
    }

    /**
     * Signed deviation of day {@code t} from the user's mean rating day:
     * {@code sign(t - tu) * |t - tu|^beta}.
     *
     * @param userId user index
     * @param t      day index
     * @return the deviation value
     */
    private double dev(int userId, int t) {
        double tu = userMeanDate.get(userId);
        double diff = (double)t - tu;
        return Math.signum(diff) * Math.pow(Math.abs(diff), beta);
    }

    /**
     * Maps a day index to a bin index by scaling {@code day / numDays} to {@code numBins}
     * and truncating.
     *
     * @param day day index
     * @return the bin index
     */
    private int bin(int day) {
        return (int)((double)day / numDays * numBins);
    }

    /**
     * Converts a millisecond span to whole days.
     *
     * @param diff time span in milliseconds
     * @return the number of whole days in {@code diff}
     */
    private static int days(long diff) {
        return (int)TimeUnit.MILLISECONDS.toDays(diff);
    }

    /**
     * Whole days between two millisecond timestamps, regardless of their order.
     *
     * @param t1 first timestamp
     * @param t2 second timestamp
     * @return the absolute distance in whole days
     */
    private static int days(long t1, long t2) {
        return days(Math.abs(t1 - t2));
    }

    /**
     * Scans every entry of the time matrix and stores the smallest and largest timestamp in
     * {@link #minTimestamp} and {@link #maxTimestamp}.
     */
    private void getMaxAndMinTimeStamp() {
        minTimestamp = Long.MAX_VALUE;
        maxTimestamp = Long.MIN_VALUE;
        for (MatrixEntry entry : timeMatrix) {
            long timeStamp = (long)entry.get();
            minTimestamp = Math.min(minTimestamp, timeStamp);
            maxTimestamp = Math.max(maxTimestamp, timeStamp);
        }
    }
}

