from madmom.audio.signal import SignalProcessor, FramedSignalProcessor
from madmom.audio.stft import ShortTimeFourierTransformProcessor
from madmom.audio.spectrogram import (
    FilteredSpectrogramProcessor, LogarithmicSpectrogramProcessor,
    SpectrogramDifferenceProcessor)
from madmom.processors import ParallelProcessor, SequentialProcessor
import numpy as np
from BeatNet.common import *


# feature extractor for Magnitude spectrogoram and its differences  

class LOG_SPECT(FeatureModule):
    def __init__(self, num_channels=1, sample_rate=22050, win_length=2048, hop_size=512, n_bands=[12]):
        sig = SignalProcessor(num_channels=num_channels, win_length=win_length, sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.hop_length = hop_size
        self.num_channels = num_channels
        multi = ParallelProcessor([])
        frame_sizes = [win_length]  # [1024, 2048, 4096]
        num_bands = n_bands  # [3, 6, 12]
        for frame_size, num_bands in zip(frame_sizes, num_bands):
            frames = FramedSignalProcessor(frame_size=frame_size, hop_size=hop_size)
            stft = ShortTimeFourierTransformProcessor()  # caching FFT window
            filt = FilteredSpectrogramProcessor(
                num_bands=num_bands, fmin=30, fmax=17000, norm_filters=True)
            spec = LogarithmicSpectrogramProcessor(mul=1, add=1)
            diff = SpectrogramDifferenceProcessor(
                diff_ratio=0.5, positive_diffs=True, stack_diffs=np.hstack)
            # process each frame size with spec and diff sequentially
            multi.append(SequentialProcessor((frames, stft, filt, spec, diff)))
        # stack the features and processes everything sequentially
        self.pipe = SequentialProcessor((sig, multi, np.hstack))

    def process_audio(self, audio):
        feats = self.pipe(audio)
        return feats.T

# from timeit import default_timer as timer
# import matplotlib.pyplot as plt
# import librosa.display
# # y1, sr1 = librosa.load(librosa.ex('trumpet'),sr=500)
# # mel = MEL(sample_rate=sr1, win_length=32, mel_n_fft=2048, hop_size=10, n_mels=128, fmin=0.0, fmax=None,
# #                  diffs=True)
# # #
# y1, sr1 = librosa.load(librosa.ex('trumpet'),sr=22050)
#
# log_spect = LOG_SPECT(num_channels=1, sample_rate=22050, win_length=2048, hop_size=512, n_bands=[12])
#
# # #
# start = timer()
# S = log_spect.process_audio(y1)
# end = timer()
# print(end - start)
# fig, ax = plt.subplots()
# img = librosa.display.specshow(librosa.amplitude_to_db(S, ref=np.max), y_axis='hz', x_axis='time', ax=ax)
# ax.set_title('Power spectrogram')
# fig.colorbar(img, ax=ax, format="%+2.0f dB")
# plt.show()