import numpy as np
import librosa
import torch
import torch.nn as nn
# import pywt
from scipy import signal



def _morlet2(num_points, scale, w=5.0):
    """
    Complex Morlet wavelet with the same definition as the (removed) scipy.signal.morlet2.

    Returns ``num_points`` samples of
    ``pi**-0.25 * exp(1j*w*x) * exp(-x**2 / 2) / sqrt(scale)``, where ``x`` is the
    sample offset from the window centre divided by ``scale``.
    """
    x = (np.arange(0, num_points) - (num_points - 1.0) / 2) / scale
    return np.sqrt(1.0 / scale) * np.exp(1j * w * x) * np.exp(-0.5 * x ** 2) * np.pi ** (-0.25)


def compute_cwt_power_spectrum(audio, sample_rate, num_freqs=128, f_min=20, f_max=None, w=5.0):
    """
    Compute the power spectrum of the continuous wavelet transform using a Morlet wavelet.

    Parameters:
        audio: torch.Tensor
            Input audio signal (1-D); moved to CPU and converted to NumPy.
        sample_rate: int
            Sampling rate of the audio
        num_freqs: int
            Number of frequency bins (rows) for the CWT
        f_min: float
            Minimum frequency to analyze, in Hz (must be positive)
        f_max: float or None
            Maximum frequency to analyze, in Hz (defaults to the Nyquist frequency)
        w: float
            Morlet parameter (number of oscillations); 5.0 matches the previous behaviour.

    Returns:
        torch.Tensor: float32 power |CWT|**2 of shape (num_freqs, len(audio)); row i
        corresponds to the i-th of ``num_freqs`` log-spaced frequencies from f_min to f_max.

    Raises:
        ValueError: if f_min or f_max is not positive (their logarithm is taken).
    """
    audio_np = audio.cpu().numpy()

    # Set default f_max to Nyquist frequency if not specified
    if f_max is None:
        f_max = sample_rate // 2
    if f_min <= 0 or f_max <= 0:
        raise ValueError("f_min and f_max must be positive")

    # Logarithmically spaced analysis frequencies
    frequencies = np.logspace(np.log10(f_min), np.log10(f_max), num=num_freqs)

    # For the Morlet wavelet the scale (in samples) centred on frequency f is
    # s = w * fs / (2 * pi * f). Omitting w would shift every row to w times
    # its labelled frequency.
    widths = w * sample_rate / (2 * np.pi * frequencies)

    # Same algorithm as the former scipy.signal.cwt (removed in SciPy 1.15):
    # correlate the signal with a wavelet truncated to 10 scales (or the signal length).
    cwt = np.empty((len(widths), len(audio_np)), dtype=np.complex128)
    for row, width in enumerate(widths):
        num_points = min(10 * width, len(audio_np))
        kernel = np.conj(_morlet2(num_points, width, w)[::-1])
        cwt[row] = signal.convolve(audio_np, kernel, mode='same')

    # Power spectrum (magnitude squared) as a float32 tensor
    return torch.FloatTensor(np.abs(cwt) ** 2)

# def compute_wavelet_transform(audio, wavelet, decompos_level):
#     """Compute wavelet decomposition of the audio signal."""
#     # Convert to numpy and ensure 1D
#     audio_np = audio.cpu().numpy()
#
#     # Perform wavelet decomposition
#     coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
#
#     # Stack coefficients into a 2D array
#     # First, pad all coefficient arrays to the same length
#     max_len = max(len(c) for c in coeffs)
#     padded_coeffs = []
#     for coeff in coeffs:
#         pad_len = max_len - len(coeff)
#         if pad_len > 0:
#             padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
#         else:
#             padded_coeff = coeff
#         padded_coeffs.append(padded_coeff)
#
#     # Stack into 2D array where each row is a different scale
#     wavelet_features = np.stack(padded_coeffs)
#
#     # Convert to tensor
#     return torch.FloatTensor(wavelet_features)


def compute_melspectrogram(audio, sample_rate):
    """
    Return a 128-band mel spectrogram of ``audio`` in decibels.

    The tensor is moved to the CPU and converted to NumPy for librosa. The power
    mel spectrogram is converted to dB with ``librosa.power_to_db`` (default
    reference) and returned as a FloatTensor.
    """
    samples = audio.cpu().numpy()
    power_mel = librosa.feature.melspectrogram(y=samples, sr=sample_rate, n_mels=128)
    mel_db = librosa.power_to_db(power_mel)
    return torch.FloatTensor(mel_db)


def compute_mfcc(audio, sample_rate):
    """
    Return 20 MFCCs per frame of ``audio`` as a FloatTensor.

    The tensor is moved to the CPU and converted to NumPy before calling
    ``librosa.feature.mfcc`` with librosa's default framing.
    """
    samples = audio.cpu().numpy()
    coefficients = librosa.feature.mfcc(y=samples, sr=sample_rate, n_mfcc=20)
    return torch.FloatTensor(coefficients)


def compute_chroma(audio, sample_rate):
    """
    Return the STFT-based chromagram of ``audio`` as a FloatTensor.

    The tensor is moved to the CPU and converted to NumPy before calling
    ``librosa.feature.chroma_stft`` with librosa's default parameters.
    """
    samples = audio.cpu().numpy()
    chromagram = librosa.feature.chroma_stft(y=samples, sr=sample_rate)
    return torch.FloatTensor(chromagram)


def compute_time_domain_features(audio, sample_rate, frame_length=2048, hop_length=128):
    """
    Compute scalar time-domain (and tempo) summary features from an audio signal.

    Parameters:
        audio: torch.Tensor
            Input audio signal; moved to CPU and converted to NumPy for librosa.
        sample_rate: int
            Sampling rate, used for onset strength and tempo estimation.
        frame_length: int
            Frame size for ZCR, RMS, the framed statistics and the STFT.
        hop_length: int
            Hop between frames for the same computations (onset strength and
            tempo use librosa's default hop instead).

    Returns:
        dict[str, torch.Tensor] with keys 'zcr', 'rms_energy', 'mean', 'std',
        'max', 'envelope' (1-element tensors) and 'tempo' (whatever librosa's
        tempo estimate returns — presumably a single BPM value).
    """
    audio_np = audio.cpu().numpy()
    frame_kwargs = {'frame_length': frame_length, 'hop_length': hop_length}
    features = {}

    # 1. Zero Crossing Rate.
    # NOTE(review): the per-frame rates are *summed* (every other statistic here
    # is averaged), so this value grows with signal duration — confirm intended.
    zcr = librosa.feature.zero_crossing_rate(y=audio_np, **frame_kwargs)
    features['zcr'] = torch.Tensor([zcr.sum()])

    # 2. Root Mean Square Energy, averaged over frames.
    rms = librosa.feature.rms(y=audio_np, **frame_kwargs)
    features['rms_energy'] = torch.Tensor([rms.mean()])

    # 3. Temporal statistics: per-frame mean / std / max, each averaged over frames
    # (with librosa's default axis=-1 each frame is a column, so axis=0 reduces
    # within a frame). NOTE(review): framing presumably fails for signals shorter
    # than frame_length — confirm whether callers need padding.
    frames = librosa.util.frame(audio_np, **frame_kwargs)
    features['mean'] = torch.Tensor([np.mean(frames, axis=0).mean()])
    features['std'] = torch.Tensor([np.std(frames, axis=0).mean()])
    features['max'] = torch.Tensor([np.max(frames, axis=0).mean()])

    # 4. Tempo from the onset-strength envelope. librosa 0.10 moved tempo to
    # librosa.feature.tempo and deprecated librosa.beat.tempo (slated for removal
    # in 1.0); use the old location only on older librosa versions.
    onset_env = librosa.onset.onset_strength(y=audio_np, sr=sample_rate)
    tempo_fn = getattr(librosa.feature, 'tempo', None) or librosa.beat.tempo
    tempo = tempo_fn(onset_envelope=onset_env, sr=sample_rate)
    features['tempo'] = torch.Tensor(tempo)

    # 5. "Amplitude envelope": mean STFT magnitude over all bins and frames. This
    # is a spectral level rather than a time-domain envelope; key kept as-is.
    envelope = np.abs(librosa.stft(audio_np, n_fft=frame_length, hop_length=hop_length))
    features['envelope'] = torch.Tensor([np.mean(envelope, axis=0).mean()])

    return features

def _best_effort(name, compute):
    """Return ``compute()``, or None after logging a warning if it raises."""
    import logging

    try:
        return compute()
    except Exception as exc:  # deliberate best-effort, but logged rather than silently swallowed
        logging.getLogger(__name__).warning("Could not compute %s: %s", name, exc)
        return None


def _scalar_feature(name, compute):
    """Wrap the scalar result of ``compute()`` in a 1-element FloatTensor (NaN if it fails)."""
    value = _best_effort(name, compute)
    return torch.FloatTensor([np.nan if value is None else value])


def _finite_mean(values):
    """Mean of the finite entries of a 1-D array; NaN when there are none."""
    finite = values[np.isfinite(values)]
    return finite.mean() if finite.size else np.nan


def _spectral_flux_std(magnitude):
    """Std of the frame-to-frame magnitude difference, zero-padded back to the frame count."""
    flux = np.diff(magnitude, axis=1)
    flux = np.pad(flux, ((0, 0), (1, 0)), mode='constant')
    return np.std(flux)


def _spectral_standardized_moment(magnitude, frequencies, order):
    """
    Frame-averaged standardized moment of each spectrum viewed as a distribution over frequency.

    order=3 gives spectral skewness; order=4 gives spectral kurtosis (non-excess,
    i.e. a Gaussian-shaped spectrum scores 3). Frames with no energy or zero
    spread are undefined and skipped.
    """
    freqs = frequencies[:, np.newaxis]
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = magnitude / magnitude.sum(axis=0, keepdims=True)
        centroid = np.sum(freqs * weights, axis=0, keepdims=True)
        deviation = freqs - centroid
        spread = np.sqrt(np.sum(deviation ** 2 * weights, axis=0))
        moment = np.sum(deviation ** order * weights, axis=0) / spread ** order
    return _finite_mean(moment)


def _spectral_slope(magnitude, frequencies):
    """Frame-averaged least-squares slope of magnitude versus frequency (magnitude per Hz)."""
    centered = frequencies - frequencies.mean()
    slopes = centered @ magnitude / np.sum(centered ** 2)
    return _finite_mean(slopes)


def compute_frequency_domain_features(audio, sample_rate, n_fft=2048, hop_length=512):
    """
    Compute frequency-domain summary features from an audio signal.

    Parameters:
        audio: torch.Tensor
            Input audio signal; moved to CPU and converted to NumPy for librosa.
        sample_rate: int
            Sampling rate of the audio.
        n_fft: int
            FFT size for every STFT-based feature (tonnetz uses librosa defaults).
        hop_length: int
            Hop between STFT frames for the same features.

    Returns:
        dict[str, torch.FloatTensor], each a 1-element tensor, in this order:
        spectral_centroid, spectral_rolloff, spectral_bandwidth (frame maxima),
        spectral_contrast (mean), spectral_flatness (max), spectral_flux (std of
        frame differences), mfcc_mean, chroma_mean, spectral_kurtosis,
        spectral_skewness, spectral_slope, tonnetz_mean (frame means).

    Extraction is best-effort: a feature whose computation raises is logged as a
    warning and stored as NaN instead of aborting the whole call.
    """
    audio_np = audio.cpu().numpy()
    stft_args = {'n_fft': n_fft, 'hop_length': hop_length}

    # Magnitude spectrogram shared by the flux / moment / slope features (same
    # STFT call the spectral-flux code always used). Bin k lies at k*sr/n_fft Hz.
    magnitude = _best_effort(
        'magnitude STFT', lambda: np.abs(librosa.stft(audio_np, **stft_args))
    )
    frequencies = np.fft.rfftfreq(n_fft, d=1.0 / sample_rate)

    def from_magnitude(name, reduce):
        # Spectrum-derived features are NaN when the shared STFT failed (already logged).
        if magnitude is None:
            return torch.FloatTensor([np.nan])
        return _scalar_feature(name, lambda: reduce(magnitude))

    features = {}

    # 1-3. Centroid, rolloff, bandwidth: maximum over frames.
    features['spectral_centroid'] = _scalar_feature(
        'spectral_centroid',
        lambda: librosa.feature.spectral_centroid(y=audio_np, sr=sample_rate, **stft_args).max(),
    )
    features['spectral_rolloff'] = _scalar_feature(
        'spectral_rolloff',
        lambda: librosa.feature.spectral_rolloff(y=audio_np, sr=sample_rate, **stft_args).max(),
    )
    features['spectral_bandwidth'] = _scalar_feature(
        'spectral_bandwidth',
        lambda: librosa.feature.spectral_bandwidth(y=audio_np, sr=sample_rate, **stft_args).max(),
    )

    # 4. Spectral contrast: lower fmin and fewer bands than librosa's defaults, mean over bands/frames.
    features['spectral_contrast'] = _scalar_feature(
        'spectral_contrast',
        lambda: librosa.feature.spectral_contrast(
            y=audio_np, sr=sample_rate, fmin=20, n_bands=4, quantile=0.02, **stft_args
        ).mean(),
    )

    # 5. Spectral flatness: maximum over frames.
    features['spectral_flatness'] = _scalar_feature(
        'spectral_flatness',
        lambda: librosa.feature.spectral_flatness(y=audio_np, **stft_args).max(),
    )

    # 6. Spectral flux: std of the frame-to-frame magnitude change.
    features['spectral_flux'] = from_magnitude('spectral_flux', _spectral_flux_std)

    # 7-8. Mean of 13 MFCCs and of the STFT chromagram.
    features['mfcc_mean'] = _scalar_feature(
        'mfcc_mean',
        lambda: librosa.feature.mfcc(y=audio_np, sr=sample_rate, n_mfcc=13, **stft_args).mean(),
    )
    features['chroma_mean'] = _scalar_feature(
        'chroma_mean',
        lambda: librosa.feature.chroma_stft(y=audio_np, sr=sample_rate, **stft_args).mean(),
    )

    # 9-11. librosa has no spectral kurtosis/skewness/slope, so the previous calls
    # always failed and yielded NaN; compute them from the magnitude spectrum.
    features['spectral_kurtosis'] = from_magnitude(
        'spectral_kurtosis', lambda mag: _spectral_standardized_moment(mag, frequencies, 4)
    )
    features['spectral_skewness'] = from_magnitude(
        'spectral_skewness', lambda mag: _spectral_standardized_moment(mag, frequencies, 3)
    )
    features['spectral_slope'] = from_magnitude(
        'spectral_slope', lambda mag: _spectral_slope(mag, frequencies)
    )

    # 12. Tonnetz (tonal centroid features), librosa defaults.
    features['tonnetz_mean'] = _scalar_feature(
        'tonnetz_mean',
        lambda: librosa.feature.tonnetz(y=audio_np, sr=sample_rate).mean(),
    )

    return features


def compute_all_features(audio, sample_rate, wavelet='db1', decompos_level=4):
    """
    Compute the feature set currently enabled for the pipeline.

    Despite the name, only the frequency-domain features are produced at present:
    the result is exactly ``compute_frequency_domain_features(audio, sample_rate)``,
    a flat dict of 1-element tensors. The other extractors in this module
    (mel spectrogram, MFCC, chroma, CWT power, time-domain features) are not called.

    Parameters:
        audio: torch.Tensor
            Input audio signal.
        sample_rate: int
            Sampling rate of the audio.
        wavelet, decompos_level:
            Unused. Kept for backward compatibility with the disabled wavelet
            decomposition (which depended on pywt).

    Returns:
        dict[str, torch.Tensor]: the frequency-domain feature dictionary.
    """
    # To enable further extractors, merge their outputs into the returned dict,
    # e.g. result['melspectrogram'] = compute_melspectrogram(audio, sample_rate).
    return compute_frequency_domain_features(audio, sample_rate)