File size: 1,969 Bytes
e6dfae0
435c886
cd352f8
 
 
f8ceff7
e6dfae0
435c886
954dbdf
cd352f8
 
f8ceff7
cd352f8
 
 
 
 
 
 
 
 
f8ceff7
cd352f8
f8ceff7
cd352f8
f8ceff7
 
 
cd352f8
f8ceff7
cd352f8
 
f8ceff7
cd352f8
 
f8ceff7
 
 
 
cd352f8
f8ceff7
cd352f8
 
f8ceff7
cd352f8
 
 
c277d5c
 
 
2dc1426
c277d5c
435c886
 
c277d5c
 
 
e6dfae0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import scipy.signal as ssig

def plot_stft(audio_file):
    """Render an STFT magnitude spectrogram of an audio file to a PNG image.

    The audio is loaded with ``librosa.load`` using its default options,
    transformed with a Hann-windowed short-time Fourier transform, converted
    to decibels relative to the peak magnitude, and drawn as a Plotly heatmap
    (time on x, frequency on y).

    Args:
        audio_file: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``. ``None`` means no audio was given.

    Returns:
        Path to a newly created PNG file containing the spectrogram, or
        ``None`` when ``audio_file`` is ``None``.
    """
    if audio_file is None:
        # Gradio passes None when the user submits without providing audio;
        # returning None leaves the image output empty instead of crashing.
        return None

    audio, sampling_rate = librosa.load(audio_file)

    # 512-sample Hann frames with 412 samples of overlap (hop = 100 samples).
    # nfft=1024 zero-pads each frame, giving a finer-spaced frequency grid.
    # Passing fs makes `frames` come back in seconds and `freq` in Hz.
    freq, frames, stft = ssig.stft(audio,
                                   sampling_rate,
                                   window='hann',
                                   nperseg=512,
                                   noverlap=412,
                                   nfft=1024,
                                   return_onesided=True,
                                   boundary='zeros',
                                   padded=True,
                                   axis=-1)

    # Magnitude in dB, relative to the loudest bin (0 dB = peak).
    magnitude_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
    spectrogram = go.Heatmap(z=magnitude_db,
                             x=frames,
                             y=freq,
                             colorscale='Viridis')
    fig = go.Figure(spectrogram)

    axis_font = dict(family='Latin Modern Roman', size=18)
    # `title=dict(text=..., font=...)` replaces the deprecated `titlefont`
    # axis attribute, which newer Plotly releases no longer accept.
    fig.update_layout(
        font=axis_font,
        xaxis=dict(title=dict(text='Time (seconds)', font=axis_font)),
        yaxis=dict(title=dict(text='Frequency (Hz)', font=axis_font)),
        margin=dict(l=0, r=0, t=0, b=0),
    )

    fig.update_traces(colorbar_thickness=8, selector=dict(type='heatmap'))
    fig.update_traces(showscale=True, showlegend=False, visible=True)
    fig.update_xaxes(visible=True, showgrid=False)
    fig.update_yaxes(visible=True, showgrid=False)

    # Use a unique file per call: a fixed filename would let concurrent
    # requests overwrite each other's image, and would depend on the CWD.
    fd, image_path = tempfile.mkstemp(prefix='stft_', suffix='.png')
    os.close(fd)  # only the path is needed; write_image reopens the file
    fig.write_image(image_path)

    return image_path

# Gradio interface: takes an audio clip (as a file path) and shows the STFT
# spectrogram image produced by plot_stft.
demo = gr.Interface(
    fn=plot_stft,
    inputs=gr.Audio(type="filepath"),
    outputs="image",
)

# Only start the (blocking) web server when run as a script, so importing this
# module does not launch it as a side effect. `demo` stays at module level for
# tools that look it up by name.
if __name__ == "__main__":
    demo.launch()