from math import log2

import librosa
import numpy as np


def _get_n_fft(freq_res_hz: int, sr: int) -> int:
    """
    :freq_res: frequency resolution in Hz = sample_rate / n_fft
              how good you can differentiate between frequency components
              which are at least ‘this’ amount far apart.
    :sr: sampling_rate

    The n_fft specifies the FFT length, i.e. the number of bins.
    Low frequencies are more distinguishable when n_fft is higher.
    For computational reason n_fft is a power of 2 (2, 4, 8, 16, ...)
    """
    return 2 ** round(log2(sr / freq_res_hz))


def get_spectrogram_dB(
    raw_wave: np.ndarray, freq_res_hz: int = 5, sr: int = 12000
) -> np.ndarray:
    """
    Return the STFT magnitude spectrogram of a waveform on a dB scale.

    :raw_wave: audio samples handed to ``librosa.stft`` as ``y``.
    :freq_res_hz: desired frequency resolution in Hz; together with ``sr`` it
                  determines the FFT length via ``_get_n_fft``.
    :sr: sampling rate of ``raw_wave``.

    :returns: output of ``librosa.amplitude_to_db`` called with
              ``ref=np.max``, i.e. the dB reference is taken from the
              maximum of the magnitude spectrogram.
    """
    n_fft = _get_n_fft(freq_res_hz, sr)
    # The STFT is complex-valued; only the magnitude is converted to dB.
    magnitude = np.abs(librosa.stft(y=raw_wave, n_fft=n_fft))
    return librosa.amplitude_to_db(magnitude, ref=np.max)


def get_mel_spectrogram_dB(
    raw_wave: np.ndarray, freq_res_hz: int = 5, sr: int = 12000
) -> np.ndarray:
    """
    Return the mel-scaled magnitude spectrogram of a waveform on a dB scale.

    :raw_wave: audio samples handed to ``librosa.stft`` as ``y``.
    :freq_res_hz: desired frequency resolution in Hz; together with ``sr`` it
                  determines the FFT length via ``_get_n_fft``.
    :sr: sampling rate of ``raw_wave``; also passed on to
         ``librosa.feature.melspectrogram``.

    :returns: output of ``librosa.amplitude_to_db`` called with
              ``ref=np.max`` on the mel-scaled spectrogram.
    """
    n_fft = _get_n_fft(freq_res_hz, sr)
    # The STFT is complex-valued; the mel projection is applied to its
    # magnitude, supplied as a precomputed spectrogram through ``S``.
    magnitude = np.abs(librosa.stft(y=raw_wave, n_fft=n_fft))
    mel_magnitude = librosa.feature.melspectrogram(S=magnitude, sr=sr)
    return librosa.amplitude_to_db(mel_magnitude, ref=np.max)