File size: 2,144 Bytes
e6dfae0
435c886
cd352f8
 
 
f8ceff7
e6dfae0
435c886
954dbdf
cd352f8
 
f8ceff7
cd352f8
 
 
 
 
 
 
 
 
f8ceff7
cd352f8
f8ceff7
cd352f8
f8ceff7
 
 
cd352f8
f8ceff7
cd352f8
 
f8ceff7
cd352f8
 
f8ceff7
 
 
 
cd352f8
 
f8ceff7
 
cd352f8
 
f8ceff7
cd352f8
 
 
f8ceff7
 
2dc1426
 
435c886
 
 
 
cdc9708
e6dfae0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import scipy.signal as ssig
import librosa
import plotly.io as pio

def plot_stft(audio_file):
    """Render an STFT magnitude spectrogram of an audio file as an HTML fragment.

    The audio is loaded with ``librosa.load`` using its default settings. It is
    transformed with ``scipy.signal.stft`` (Hann window, 512-sample segments,
    412-sample overlap, 1024-point FFT, one-sided spectrum). The result is
    drawn as a dB-scaled Plotly heatmap with time on x and frequency on y.

    Args:
        audio_file: Path to the audio file, as delivered by
            ``gr.Audio(type="filepath")``; ``None`` when no audio was given.

    Returns:
        An HTML fragment (``full_html=False``, not a full document) containing
        the Plotly figure.

    Raises:
        gr.Error: If ``audio_file`` is ``None``.
    """
    if audio_file is None:
        # Gradio passes None when the user submits without uploading/recording;
        # surface a readable message instead of a librosa traceback.
        raise gr.Error("Please upload or record an audio file first.")

    audio, sampling_rate = librosa.load(audio_file)

    # `times` are segment times in seconds and `freq` bin frequencies in Hz,
    # because the true sampling rate is passed as `fs`.
    freq, times, stft = ssig.stft(
        audio,
        fs=sampling_rate,
        window="hann",
        nperseg=512,
        noverlap=412,
        nfft=1024,
        return_onesided=True,
        boundary="zeros",
        padded=True,
        axis=-1,
    )

    # Magnitude in dB relative to the loudest bin (ref=np.max -> peak at 0 dB).
    magnitude_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)

    spectrogram = go.Heatmap(
        z=magnitude_db,
        x=times,
        y=freq,
        colorscale="Viridis",
    )
    fig = go.Figure(spectrogram)

    # `title=dict(text=..., font=...)` is used instead of the deprecated
    # `titlefont` axis attribute, which newer Plotly releases reject.
    fig.update_layout(
        font=dict(family="Latin Modern Roman", size=18),
        xaxis=dict(title=dict(text="Time (seconds)",
                              font=dict(family="Latin Modern Roman", size=18))),
        yaxis=dict(title=dict(text="Frequency (Hz)",
                              font=dict(family="Latin Modern Roman", size=18))),
        margin=dict(l=0, r=0, t=0, b=0),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )

    fig.update_traces(colorbar_thickness=8, selector=dict(type="heatmap"))
    fig.update_traces(showscale=True, showlegend=False, visible=True)
    fig.update_xaxes(visible=True, showgrid=False)
    fig.update_yaxes(visible=True, showgrid=False)

    config = {
        "displaylogo": False,
        # IDs must match Plotly's modebar button names exactly (the 2-D zoom /
        # reset buttons carry a "2d" suffix); other names would have no effect.
        "modeBarButtonsToRemove": ["toImage", "zoomIn2d", "zoomOut2d", "resetScale2d"],
    }

    # NOTE(review): gr.HTML may not execute <script> tags inserted into the
    # page, in which case this fragment will not render; returning `fig` and
    # using `outputs=gr.Plot()` is the Gradio-native alternative -- confirm.
    return pio.to_html(fig, full_html=False, config=config)

# Gradio UI: one audio input, handed to plot_stft as a file path, wired to an
# HTML output that displays the spectrogram fragment plot_stft returns.
demo = gr.Interface(
    fn=plot_stft,
    inputs=gr.Audio(type="filepath"),
    outputs="html",
)

# Start the server as soon as this module executes (there is no __main__ guard).
demo.launch()