/*
 * Decompiled with CFR 0.152.
 */
package net.librec.math.structure;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import net.librec.math.structure.SparseMatrix;

/**
 * A {@link SparseMatrix} intended to hold only diagonal entries.
 *
 * <p>Every arithmetic helper in this class reads and writes only the entries {@code (i, i)} for
 * {@code 0 <= i < numRows}. Methods without the {@code Equal} suffix return a new matrix obtained
 * via {@link #clone()}; methods with the {@code Equal} suffix modify this matrix in place and
 * return {@code this} for chaining.
 *
 * <p>NOTE(review): writes go through {@code SparseMatrix.set}, which presumably requires the
 * diagonal entry to already exist in the sparse structure; matrices built with {@link #eye(int)}
 * satisfy this, other constructions should be checked by the caller.
 */
public class DiagMatrix extends SparseMatrix {
    private static final long serialVersionUID = -9186836460633909994L;

    /**
     * Creates a diagonal matrix from pre-built sparse structure.
     *
     * @param rows      number of rows
     * @param cols      number of columns
     * @param dataTable stored entries keyed by (row, column)
     * @param colMap    column index structure passed straight to {@link SparseMatrix}'s
     *                  constructor (presumably column -> rows holding an entry; confirm against
     *                  SparseMatrix). Only diagonal entries are expected, but this is not checked.
     */
    public DiagMatrix(int rows, int cols, Table<Integer, Integer, Double> dataTable, Multimap<Integer, Integer> colMap) {
        super(rows, cols, dataTable, colMap);
    }

    /**
     * Copy constructor; delegates to {@link SparseMatrix}'s copy constructor.
     *
     * @param mat the matrix to copy
     */
    public DiagMatrix(DiagMatrix mat) {
        super(mat);
    }

    /**
     * Returns a copy of this matrix built with the copy constructor.
     *
     * @return a new {@code DiagMatrix} copied from this one
     */
    @Override
    public DiagMatrix clone() {
        return new DiagMatrix(this);
    }

    /**
     * Returns a copy of this matrix with every diagonal entry multiplied by {@code val}.
     *
     * @param val the scale factor
     * @return a new scaled matrix
     */
    public DiagMatrix scale(double val) {
        return this.clone().scaleEqual(val);
    }

    /**
     * Multiplies every diagonal entry of this matrix by {@code val}, in place.
     *
     * @param val the scale factor
     * @return {@code this}
     */
    public DiagMatrix scaleEqual(double val) {
        for (int i = 0; i < this.numRows; ++i) {
            this.set(i, i, this.get(i, i) * val);
        }
        return this;
    }

    /**
     * Returns the element-wise sum of the diagonals of this matrix and {@code that}.
     *
     * @param that the matrix to add; must have the same number of rows
     * @return a new matrix holding the sum
     * @throws IllegalArgumentException if the row counts differ
     */
    public DiagMatrix add(DiagMatrix that) {
        this.requireSameDimension(that);
        return this.clone().addEqual(that);
    }

    /**
     * Adds the diagonal of {@code that} to the diagonal of this matrix, in place.
     *
     * @param that the matrix to add; must have the same number of rows
     * @return {@code this}
     * @throws IllegalArgumentException if the row counts differ
     */
    public DiagMatrix addEqual(DiagMatrix that) {
        this.requireSameDimension(that);
        for (int i = 0; i < this.numRows; ++i) {
            this.set(i, i, this.get(i, i) + that.get(i, i));
        }
        return this;
    }

    /**
     * Returns a copy of this matrix with {@code val} added to every diagonal entry.
     *
     * @param val the value to add
     * @return a new matrix
     */
    public DiagMatrix add(double val) {
        return this.clone().addEqual(val);
    }

    /**
     * Adds {@code val} to every diagonal entry of this matrix, in place.
     *
     * @param val the value to add
     * @return {@code this}
     */
    public DiagMatrix addEqual(double val) {
        for (int i = 0; i < this.numRows; ++i) {
            this.set(i, i, this.get(i, i) + val);
        }
        return this;
    }

    /**
     * Returns the element-wise difference {@code this - that} of the diagonals.
     *
     * @param that the matrix to subtract; must have the same number of rows
     * @return a new matrix holding the difference
     * @throws IllegalArgumentException if the row counts differ
     */
    public DiagMatrix minus(DiagMatrix that) {
        this.requireSameDimension(that);
        return this.clone().minusEqual(that);
    }

    /**
     * Subtracts the diagonal of {@code that} from the diagonal of this matrix, in place.
     *
     * @param that the matrix to subtract; must have the same number of rows
     * @return {@code this}
     * @throws IllegalArgumentException if the row counts differ
     */
    public DiagMatrix minusEqual(DiagMatrix that) {
        this.requireSameDimension(that);
        for (int i = 0; i < this.numRows; ++i) {
            this.set(i, i, this.get(i, i) - that.get(i, i));
        }
        return this;
    }

    /**
     * Returns a copy of this matrix with {@code val} subtracted from every diagonal entry.
     *
     * @param val the value to subtract
     * @return a new matrix
     */
    public DiagMatrix minus(double val) {
        return this.clone().minusEqual(val);
    }

    /**
     * Subtracts {@code val} from every diagonal entry of this matrix, in place.
     *
     * @param val the value to subtract
     * @return {@code this}
     */
    public DiagMatrix minusEqual(double val) {
        for (int i = 0; i < this.numRows; ++i) {
            this.set(i, i, this.get(i, i) - val);
        }
        return this;
    }

    /**
     * Builds the {@code n x n} identity matrix: 1.0 at every {@code (i, i)}, nothing else stored.
     *
     * @param n the dimension; must be non-negative
     * @return a new identity matrix
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public static DiagMatrix eye(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Identity size must be non-negative: " + n);
        }
        Table<Integer, Integer, Double> vals = HashBasedTable.create();
        Multimap<Integer, Integer> colMap = HashMultimap.create();
        for (int i = 0; i < n; ++i) {
            vals.put(i, i, 1.0);
            colMap.put(i, i);
        }
        return new DiagMatrix(n, n, vals, colMap);
    }

    /**
     * Ensures {@code that} has the same number of rows as this matrix.
     *
     * <p>An explicit check is used rather than {@code assert}, because assertions are disabled
     * unless the JVM is started with {@code -ea}. Row counts are compared because the element-wise
     * loops read {@code that.get(i, i)} for every row {@code i} of this matrix.
     *
     * @param that the other operand
     * @throws IllegalArgumentException if the row counts differ
     */
    private void requireSameDimension(DiagMatrix that) {
        if (this.numRows != that.numRows) {
            throw new IllegalArgumentException("Dimension mismatch: this matrix has " + this.numRows
                    + " rows, the other has " + that.numRows);
        }
    }
}

