import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import base64
from io import BytesIO
import soundfile as sf
import numpy as np
import IPython.display as ipd
import plotly.graph_objects as go
import plotly
from scipy.signal import csd
from scipy.ndimage import interpolation
from plotly import tools
import scipy.signal as ssig
import matplotlib.pyplot as plt
import librosa
import librosa.display


# Plotly modebar buttons to hide. For 2-D cartesian plots such as a heatmap
# the zoom/reset buttons are named with a "2d" suffix; the un-suffixed names
# ('zoomIn', 'zoomOut', 'resetScale') do not match any button here.
_PLOTLY_CONFIG = {
    'displaylogo': False,
    'modeBarButtonsToRemove': ['toImage', 'zoomIn2d', 'zoomOut2d', 'resetScale2d'],
}

# Height of the rendered figure, in CSS pixels.
_PLOT_HEIGHT_PX = 500

_FONT_FAMILY = 'Latin Modern Roman'
_FONT_SIZE = 18


def _axis_title(text):
    """Return a Plotly axis-title spec that uses the app's title font."""
    return dict(text=text, font=dict(family=_FONT_FAMILY, size=_FONT_SIZE))


def _spectrogram_figure(times, freq, magnitude_db):
    """Build the styled Plotly heatmap figure for a dB-scaled STFT magnitude.

    Args:
        times: Segment times in seconds (x axis).
        freq: Frequency bins in Hz (y axis).
        magnitude_db: 2-D array of dB values, rows indexed by ``freq`` and
            columns by ``times``.
    """
    heatmap = go.Heatmap(z=magnitude_db, x=times, y=freq, colorscale='Viridis')
    fig = go.Figure(heatmap)
    fig.update_layout(
        font=dict(family=_FONT_FAMILY, size=_FONT_SIZE),
        # `title=dict(text=..., font=...)` replaces the deprecated `titlefont`.
        xaxis=dict(title=_axis_title('Time (seconds)')),
        yaxis=dict(title=_axis_title('Frequency (Hz)')),
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )
    fig.update_traces(colorbar_thickness=8, selector=dict(type='heatmap'))
    fig.update_traces(showscale=True, showlegend=False, visible=True)
    fig.update_xaxes(visible=True, showgrid=False)
    fig.update_yaxes(visible=True, showgrid=False)
    return fig


def _embed_in_iframe(page, height_px):
    """Wrap a standalone HTML page in an ``<iframe srcdoc=...>``.

    A Gradio HTML component injects its value as markup, and <script> tags
    inserted that way are not executed by the browser; putting the page in an
    iframe's ``srcdoc`` gives it its own document, where Plotly's scripts run.
    """
    import html

    # quote=True also escapes '"' so the page cannot break out of the attribute.
    escaped = html.escape(page, quote=True)
    return ('<iframe style="width:100%; border:none;" '
            f'height="{height_px}" srcdoc="{escaped}"></iframe>')


def plot_stft(audio_file) -> str:
    """Render an interactive STFT spectrogram of an audio file as HTML.

    Args:
        audio_file: Path to the audio file (as produced by
            ``gr.Audio(type="filepath")``), or ``None`` when nothing was uploaded.

    Returns:
        An HTML snippet (an iframe embedding a Plotly heatmap of the STFT
        magnitude in dB) for a Gradio HTML output, or a short message when
        ``audio_file`` is ``None``.
    """
    if audio_file is None:
        return '<p>No audio file provided.</p>'

    # With default arguments librosa.load resamples to 22050 Hz and downmixes
    # to mono.
    audio, sampling_rate = librosa.load(audio_file)

    # 512-sample Hann segments with a hop of 100 samples (512 - 412), each
    # zero-padded to 1024 points. Because fs is passed, `freq` is in Hz and
    # `times` is in seconds.
    freq, times, stft = ssig.stft(audio,
                                  sampling_rate,
                                  window='hann',
                                  nperseg=512,
                                  noverlap=412,
                                  nfft=1024,
                                  return_onesided=True,
                                  boundary='zeros',
                                  padded=True,
                                  axis=-1)

    # dB relative to the loudest bin, so the peak maps to 0 dB.
    magnitude_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
    fig = _spectrogram_figure(times, freq, magnitude_db)

    # Render to an HTML string in memory. Unlike plotly.offline.plot, this
    # writes no shared file and opens no browser on the server. plotly.js is
    # loaded from the CDN by the viewer's browser.
    page = fig.to_html(config=_PLOTLY_CONFIG,
                       include_plotlyjs='cdn',
                       default_height=f'{_PLOT_HEIGHT_PX}px')
    # Extra 20px leaves room for the iframe document's default body margin.
    return _embed_in_iframe(page, _PLOT_HEIGHT_PX + 20)

# Gradio interface: a single audio upload, handed to plot_stft as a file path,
# and an HTML panel that displays the markup plot_stft returns.
demo = gr.Interface(
    fn=plot_stft,
    inputs=gr.Audio(type="filepath"),
    outputs="html",
)

# Start the server only when run as a script. Importing this module then does
# not block or bind a port; `demo` stays at module level so Gradio's reload
# mode (`gradio app.py`) can still find it by name.
if __name__ == "__main__":
    demo.launch()