jhauret commited on
Commit
954dbdf
·
1 Parent(s): cdc9708

html with matplotlib

Browse files
Files changed (2) hide show
  1. app.py +23 -22
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,38 +1,39 @@
1
  import gradio as gr
2
  import numpy as np
3
- import plotly.graph_objects as go
4
  import librosa
 
 
 
5
 
6
  def plot_stft(audio_file):
7
- # Load the audio file
8
  y, sr = librosa.load(audio_file)
9
 
10
- # Compute the STFT
11
  D = librosa.stft(y)
12
  S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
13
 
14
- # Generate time and frequency axes
15
- times = librosa.times_like(S_db)
16
- freqs = librosa.fft_frequencies(sr=sr)
 
 
17
 
18
- # Create Plotly figure
19
- fig = go.Figure(data=go.Heatmap(
20
- z=S_db,
21
- x=times,
22
- y=freqs,
23
- colorscale='Viridis'
24
- ))
25
 
26
- # Update layout for better visualization
27
- fig.update_layout(
28
- title="STFT (Short-Time Fourier Transform)",
29
- xaxis_title="Time (s)",
30
- yaxis_title="Frequency (Hz)",
31
- yaxis_type="log"
32
- )
33
 
34
- # Return the HTML representation of the plot
35
- return fig.to_html()
 
 
36
 
37
  # Gradio interface
38
  demo = gr.Interface(fn=plot_stft,
 
1
  import gradio as gr
2
  import numpy as np
3
+ import matplotlib.pyplot as plt
4
  import librosa
5
+ import librosa.display
6
+ import base64
7
+ from io import BytesIO
8
 
9
  def plot_stft(audio_file):
10
+ # Load audio file
11
  y, sr = librosa.load(audio_file)
12
 
13
+ # Compute the Short-Time Fourier Transform (STFT)
14
  D = librosa.stft(y)
15
  S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
16
 
17
+ # Plot the STFT
18
+ plt.figure(figsize=(10, 6))
19
+ librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='log')
20
+ plt.colorbar(format='%+2.0f dB')
21
+ plt.title('STFT (Short-Time Fourier Transform)')
22
 
23
+ # Save the plot to a BytesIO object
24
+ buf = BytesIO()
25
+ plt.savefig(buf, format='png')
26
+ plt.close()
27
+ buf.seek(0)
 
 
28
 
29
+ # Encode the image as base64
30
+ image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
31
+ buf.close()
 
 
 
 
32
 
33
+ # Create an HTML img tag with the base64 encoded image
34
+ html_img = f'<img src="data:image/png;base64,{image_base64}" alt="STFT plot"/>'
35
+
36
+ return html_img
37
 
38
  # Gradio interface
39
  demo = gr.Interface(fn=plot_stft,
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
- plotly
 
 
2
  librosa
 
 
1
+ gradio
2
+ numpy
3
+ matplotlib
4
  librosa
5
+