"""Test spectral_function.py."""
import numpy as np

from phono3py.phonon3.spectral_function import SpectralFunction

# Reference frequency shifts for ``test_SpectralFunction``.
#
# Values are listed six per row, i.e. in the order of
# ``np.swapaxes(sf.shifts, -2, -1).reshape(-1, 6)``; the test compares the
# flattened form of that array against this list. Presumably each row is one
# frequency point and the six columns are the six phonon bands of Si, with
# rows 1-10 belonging to the first requested grid point and rows 11-20 to the
# second -- TODO confirm against SpectralFunction's output ordering.
# fmt: off
shifts = [
    -0.0049592, -0.0049592, -0.0120983, -0.1226471, -0.1214069, -0.1214069,
    -0.0051678, -0.0051678, -0.0128471, -0.1224616, -0.1200362, -0.1200362,
    -0.0055308, -0.0055308, -0.0122157, -0.1093754, -0.1077399, -0.1077399,
    -0.0037992, -0.0037992, -0.0089979, -0.0955525, -0.0958995, -0.0958995,
    -0.0034397, -0.0034397, -0.0107575, -0.1068741, -0.1067815, -0.1067815,
    -0.0017800, -0.0017800, -0.0102865, -0.1348585, -0.1275650, -0.1275650,
     0.0006728,  0.0006728, -0.0065349, -0.2011702, -0.2015991, -0.2015991,
     0.0021133,  0.0021133,  0.0020353, -0.0740009, -0.0833644, -0.0833644,
     0.0037739,  0.0037739,  0.0121357,  0.1597195,  0.1585307,  0.1585307,
     0.0026257,  0.0026257,  0.0103523,  0.1626420,  0.1634832,  0.1634832,

    -0.0189694, -0.0188985, -0.0415773, -0.0955391, -0.1180182, -0.1126508,
    -0.0194533, -0.0191057, -0.0420358, -0.0913521, -0.1140995, -0.1075009,
    -0.0233933, -0.0219600, -0.0466734, -0.0865867, -0.1086070, -0.1014454,
    -0.0140271, -0.0150165, -0.0344515, -0.0755416, -0.1018518, -0.0951606,
    -0.0058780, -0.0089457, -0.0256867, -0.0775726, -0.1070427, -0.1018654,
    -0.0069737, -0.0092857, -0.0333909, -0.1014042, -0.1320678, -0.1288315,
    -0.0030075, -0.0060858, -0.0245855, -0.1186313, -0.1963719, -0.1857004,
     0.0058243,  0.0030539, -0.0049966, -0.0583228, -0.0921850, -0.0893692,
     0.0141517,  0.0149365,  0.0312156,  0.0898626,  0.1454759,  0.1347802,
     0.0110954,  0.0137260,  0.0427527,  0.1280421,  0.1715647,  0.1648037,
]
# fmt: on

# Reference spectral functions for ``test_SpectralFunction``.
#
# The test multiplies ``sf.spectral_functions`` by pi before comparing, so
# these are the pi-scaled values. They are listed six per row, i.e. in the
# order of ``np.swapaxes(sf.spectral_functions * np.pi, -2, -1).reshape(-1, 6)``.
# Presumably each row is one frequency point and the six columns are the six
# phonon bands of Si, with rows 1-10 belonging to the first requested grid
# point and rows 11-20 to the second -- TODO confirm against
# SpectralFunction's output ordering.
# fmt: off
spec_funcs = [
    0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
    0.0000520, 0.0000520, 0.0070211, 0.0003925, 0.0004139, 0.0004139,
    0.0000085, 0.0000085, 0.0001872, 0.0016104, 0.0014443, 0.0014443,
    0.0000051, 0.0000051, 0.0000370, 0.0027822, 0.0025951, 0.0025951,
    0.0000004, 0.0000004, 0.0000021, 0.0197933, 0.0168956, 0.0168956,
    0.0000010, 0.0000010, 0.0000082, 0.0080833, 0.0110838, 0.0110838,
    0.0000002, 0.0000002, 0.0000031, 0.0014052, 0.0008202, 0.0008202,
    0.0000002, 0.0000002, 0.0000035, 0.0037304, 0.0039325, 0.0039325,
    0.0000000, 0.0000000, 0.0000009, 0.0009279, 0.0009800, 0.0009800,
    0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,

    0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
    0.0101136, 0.0017460, 0.0004489, 0.0005850, 0.0005048, 0.0005229,
    0.0002283, 0.0006942, 0.0050470, 0.0012772, 0.0010867, 0.0009498,
    0.0002702, 0.0004036, 0.0106017, 0.0086169, 0.0041489, 0.0035906,
    0.0000154, 0.0000295, 0.0002803, 0.0434066, 0.1278558, 0.0549209,
    0.0000166, 0.0000264, 0.0001776, 0.0018060, 0.0042557, 0.0043927,
    0.0000066, 0.0000108, 0.0001284, 0.0010011, 0.0009471, 0.0011088,
    0.0000059, 0.0000105, 0.0000738, 0.0010751, 0.0027300, 0.0026490,
    0.0000012, 0.0000033, 0.0000504, 0.0005539, 0.0009128, 0.0009358,
    0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
]
# fmt: on


def test_SpectralFunction(si_pbesol):
    """Spectral function of Si.

    Runs ``SpectralFunction`` at 300 K on a 9x9x9 mesh for the grid points
    ``grg2bzg[[1, 103]]`` with 10 frequency points. It checks the frequency
    shifts and the pi-scaled spectral functions against the module-level
    reference tables ``shifts`` and ``spec_funcs``.

    The computed arrays have their last two axes swapped and are flattened
    before comparison, which reproduces the row order of the reference tables.

    To regenerate a reference table, temporarily print the compared array
    six values per row, e.g. for the spectral functions::

        arr = np.swapaxes(sf.spectral_functions * np.pi, -2, -1)
        for line in arr.reshape(-1, 6):
            print(("%.7f, " * 6) % tuple(line))
    """
    si_pbesol.mesh_numbers = [9, 9, 9]
    si_pbesol.init_phph_interaction()
    sf = SpectralFunction(
        si_pbesol.phph_interaction,
        si_pbesol.grid.grg2bzg[[1, 103]],
        temperatures=[300],
        num_frequency_points=10,
        log_level=1,
    )
    sf.run()

    calc_shifts = np.swapaxes(sf.shifts, -2, -1).ravel()
    # The reference table holds the spectral functions multiplied by pi.
    calc_spec_funcs = np.swapaxes(sf.spectral_functions * np.pi, -2, -1).ravel()

    # assert_allclose(actual, desired): rtol is scaled by |desired|, and
    # failure messages label the sides accordingly, so the computed values
    # go first and the reference tables second.
    np.testing.assert_allclose(calc_shifts, shifts, atol=1e-2)
    np.testing.assert_allclose(calc_spec_funcs, spec_funcs, atol=1e-2, rtol=1e-2)
