# This code is part of Mthree.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=no-name-in-module

"""Test matrix elements"""
import numpy as np

from mthree import M3Mitigation


def test_reduced_matrix():
    """Check M3's reduced calibration matrix against a dense Kronecker reference.

    The matrix returned by ``M3Mitigation.reduced_cal_matrix`` for the
    8-bit ``COUNTS`` data is compared, in the infinity norm, with the
    explicit Kronecker product of the single-qubit calibration matrices in
    ``CALS`` for the same qubits.
    """
    qubits = [1, 4, 7, 10, 12, 2, 3, 5]

    # Matrix under test, computed by M3.
    mit = M3Mitigation(None)
    mit.single_qubit_cals = CALS
    M = mit.reduced_cal_matrix(COUNTS, qubits)[0]

    # Dense reference: Kronecker product of the per-qubit calibration
    # matrices.  qubits[0] is the right-most factor and every later qubit
    # is prepended on the left.  (No linear solve is involved here.)
    A = np.kron(mit.single_qubit_cals[qubits[1]], mit.single_qubit_cals[qubits[0]])
    for qubit in qubits[2:]:
        A = np.kron(mit.single_qubit_cals[qubit], A)

    assert np.linalg.norm(A - M, np.inf) < 1e-14


def counts_to_vector(counts):
    """ Return probability vector from counts dict.

    The i-th entry of the result is the probability of the i-th key of
    ``counts`` in iteration (insertion) order.  It is NOT placed at the
    index obtained by reading the key as a binary integer.  Entries past
    ``len(counts)`` stay zero.

    Parameters:
        counts (dict): Input dict of counts keyed by bitstrings.  The length
            of the first key sets the vector size.

    Returns:
        ndarray: 1D float array of length ``2**num_bits`` holding
        ``count / total_shots`` for each key, in key order.

    Raises:
        ValueError: If ``counts`` is empty.
    """
    if not counts:
        # Without this guard, next(iter(counts)) would leak a bare StopIteration.
        raise ValueError("counts must contain at least one bitstring")
    shots = sum(counts.values())
    num_bits = len(next(iter(counts)))
    vec = np.zeros(2**num_bits, dtype=float)
    for idx, val in enumerate(counts.values()):
        vec[idx] = val / shots
    return vec


def vector_to_probs(vec, counts):
    """ Return dict of probabilities.

    Pairs ``vec[i]`` with the i-th key of ``counts`` in iteration
    (insertion) order.  This is the same ordering ``counts_to_vector`` uses.
    Only the keys of ``counts`` are read; its values are ignored.

    Parameters:
        vec (ndarray): 1D vector of probabilities with at least
            ``len(counts)`` entries.
        counts (dict): Dict of counts whose keys label the entries of ``vec``.

    Returns:
        dict: Mapping from each key of ``counts`` to its probability.

    Raises:
        IndexError: If ``vec`` has fewer entries than ``counts`` has keys.
    """
    return {key: vec[idx] for idx, key in enumerate(counts)}


# Fixture: measured shot counts keyed by 8-character bitstrings.  There are
# 256 entries, one per 8-bit outcome provided the keys are distinct (a
# duplicate key in a dict literal would silently overwrite the earlier one).
# Entries are grouped by number of '1' bits: weight 0, then weight 1, ...
# up to weight 8.  The 8-bit key length matches the 8-element ``qubits``
# list used in test_reduced_matrix.
COUNTS = {'00000000': 70,
          '00000001': 33,
          '00010000': 34,
          '00000010': 30,
          '00100000': 36,
          '00000100': 20,
          '01000000': 28,
          '00001000': 31,
          '10000000': 30,
          '00010001': 34,
          '00010010': 45,
          '00010100': 56,
          '00011000': 75,
          '00100001': 50,
          '00100010': 48,
          '00100100': 30,
          '00101000': 28,
          '00000011': 46,
          '00110000': 42,
          '01000001': 59,
          '01000010': 54,
          '01000100': 42,
          '01001000': 42,
          '00000101': 30,
          '01010000': 43,
          '00000110': 47,
          '01100000': 60,
          '10000001': 48,
          '10000010': 57,
          '10000100': 30,
          '10001000': 38,
          '00001001': 33,
          '10010000': 38,
          '00001010': 33,
          '10100000': 37,
          '00001100': 36,
          '11000000': 58,
          '00010011': 22,
          '00010101': 37,
          '00010110': 35,
          '00011001': 50,
          '00011010': 38,
          '00011100': 21,
          '00100011': 37,
          '00100101': 25,
          '00100110': 34,
          '00101001': 23,
          '00101010': 38,
          '00101100': 21,
          '00110001': 34,
          '00110010': 17,
          '00110100': 23,
          '00111000': 41,
          '01000011': 28,
          '01000101': 21,
          '01000110': 23,
          '01001001': 19,
          '01001010': 13,
          '01001100': 25,
          '01010001': 29,
          '01010010': 33,
          '01010100': 26,
          '01011000': 36,
          '01100001': 40,
          '01100010': 38,
          '01100100': 37,
          '01101000': 36,
          '00000111': 26,
          '01110000': 30,
          '10000011': 24,
          '10000101': 20,
          '10000110': 20,
          '10001001': 22,
          '10001010': 18,
          '10001100': 24,
          '10010001': 23,
          '10010010': 15,
          '10010100': 17,
          '10011000': 21,
          '10100001': 32,
          '10100010': 34,
          '10100100': 23,
          '10101000': 16,
          '00001011': 23,
          '10110000': 19,
          '11000001': 24,
          '11000010': 23,
          '11000100': 30,
          '11001000': 23,
          '00001101': 31,
          '11010000': 21,
          '00001110': 30,
          '11100000': 28,
          '00010111': 47,
          '00011011': 43,
          '00011101': 53,
          '00011110': 37,
          '00100111': 40,
          '00101011': 48,
          '00101101': 39,
          '00101110': 32,
          '00110011': 32,
          '00110101': 51,
          '00110110': 44,
          '00111001': 66,
          '00111010': 49,
          '00111100': 40,
          '01000111': 47,
          '01001011': 31,
          '01001101': 33,
          '01001110': 40,
          '01010011': 38,
          '01010101': 47,
          '01010110': 40,
          '01011001': 56,
          '01011010': 38,
          '01011100': 47,
          '01100011': 45,
          '01100101': 38,
          '01100110': 31,
          '01101001': 32,
          '01101010': 38,
          '01101100': 36,
          '01110001': 35,
          '01110010': 28,
          '01110100': 49,
          '01111000': 50,
          '10000111': 36,
          '10001011': 38,
          '10001101': 37,
          '10001110': 33,
          '10010011': 25,
          '10010101': 34,
          '10010110': 35,
          '10011001': 52,
          '10011010': 41,
          '10011100': 34,
          '10100011': 45,
          '10100101': 25,
          '10100110': 34,
          '10101001': 26,
          '10101010': 19,
          '10101100': 39,
          '10110001': 21,
          '10110010': 25,
          '10110100': 31,
          '10111000': 45,
          '11000011': 41,
          '11000101': 26,
          '11000110': 39,
          '11001001': 31,
          '11001010': 29,
          '11001100': 36,
          '11010001': 25,
          '11010010': 36,
          '11010100': 39,
          '11011000': 47,
          '11100001': 53,
          '11100010': 39,
          '11100100': 35,
          '11101000': 28,
          '00001111': 45,
          '11110000': 23,
          '00011111': 36,
          '00101111': 26,
          '00110111': 26,
          '00111011': 37,
          '00111101': 23,
          '00111110': 29,
          '01001111': 16,
          '01010111': 26,
          '01011011': 31,
          '01011101': 17,
          '01011110': 23,
          '01100111': 25,
          '01101011': 25,
          '01101101': 32,
          '01101110': 26,
          '01110011': 30,
          '01110101': 26,
          '01110110': 28,
          '01111001': 30,
          '01111010': 30,
          '01111100': 31,
          '10001111': 23,
          '10010111': 22,
          '10011011': 24,
          '10011101': 21,
          '10011110': 19,
          '10100111': 18,
          '10101011': 12,
          '10101101': 14,
          '10101110': 18,
          '10110011': 26,
          '10110101': 17,
          '10110110': 20,
          '10111001': 28,
          '10111010': 19,
          '10111100': 19,
          '11000111': 16,
          '11001011': 13,
          '11001101': 20,
          '11001110': 28,
          '11010011': 14,
          '11010101': 24,
          '11010110': 24,
          '11011001': 33,
          '11011010': 28,
          '11011100': 18,
          '11100011': 36,
          '11100101': 15,
          '11100110': 15,
          '11101001': 24,
          '11101010': 26,
          '11101100': 20,
          '11110001': 28,
          '11110010': 25,
          '11110100': 27,
          '11111000': 24,
          '00111111': 40,
          '01011111': 38,
          '01101111': 30,
          '01110111': 30,
          '01111011': 50,
          '01111101': 35,
          '01111110': 32,
          '10011111': 26,
          '10101111': 27,
          '10110111': 33,
          '10111011': 46,
          '10111101': 32,
          '10111110': 29,
          '11001111': 31,
          '11010111': 35,
          '11011011': 43,
          '11011101': 36,
          '11011110': 35,
          '11100111': 30,
          '11101011': 23,
          '11101101': 38,
          '11101110': 37,
          '11110011': 30,
          '11110101': 30,
          '11110110': 38,
          '11111001': 36,
          '11111010': 44,
          '11111100': 24,
          '01111111': 27,
          '10111111': 15,
          '11011111': 17,
          '11101111': 16,
          '11110111': 24,
          '11111011': 23,
          '11111101': 17,
          '11111110': 10,
          '11111111': 19}

# Fixture: single-qubit 2x2 calibration matrices indexed by qubit number
# (27 slots, indices 0-26).  Only qubits 1, 2, 3, 4, 5, 7, 10 and 12 have
# data; every other slot is None.  Each matrix's columns sum to 1.
# NOTE(review): presumably column j is the readout distribution for
# prepared state |j> -- confirm against M3's calibration convention.
CALS = [None,
        np.array([[0.98299193, 0.01979335],
                  [0.01700807, 0.98020665]]),
        np.array([[0.96917076, 0.03369085],
                  [0.03082924, 0.96630915]]),
        np.array([[0.9858876, 0.02348826],
                  [0.0141124, 0.97651174]]),
        np.array([[0.99496994, 0.02733885],
                  [0.00503006, 0.97266115]]),
        np.array([[0.96395599, 0.18330204],
                  [0.03604401, 0.81669796]]),
        None,
        np.array([[0.98876682, 0.03722461],
                  [0.01123318, 0.96277539]]),
        None,
        None,
        np.array([[0.99187792, 0.06334372],
                  [0.00812208, 0.93665628]]),
        None,
        np.array([[0.94568855, 0.07140989],
                  [0.05431145, 0.92859011]]),
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None]
