Kaworu17 commited on
Commit
29e4b0d
·
verified ·
1 Parent(s): e63bfc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -7
app.py CHANGED
@@ -5,6 +5,8 @@ import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import soundfile as sf
7
  from scipy.signal import resample
 
 
8
 
9
  # Load YAMNet model from TensorFlow Hub
10
  yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
@@ -20,12 +22,28 @@ def load_class_map():
20
 
21
  class_names = load_class_map()
22
 
23
- # Classification function
24
- def classify_audio(file_path):
25
  try:
26
- # Load audio
 
 
 
 
 
 
 
 
 
 
 
 
27
  audio_data, sample_rate = sf.read(file_path)
28
 
 
 
 
 
29
  # Convert stereo to mono
30
  if len(audio_data.shape) > 1:
31
  audio_data = np.mean(audio_data, axis=1)
@@ -41,18 +59,19 @@ def classify_audio(file_path):
41
  audio_data = resample(audio_data, new_length)
42
  sample_rate = target_rate
43
 
44
- # Convert to tensor
45
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
46
 
47
- # Run YAMNet
48
  scores, embeddings, spectrogram = yamnet_model(waveform)
49
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
50
  top_5 = np.argsort(mean_scores)[::-1][:5]
51
 
 
52
  top_prediction = class_names[top_5[0]]
53
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
54
 
55
- # Waveform plot
56
  fig, ax = plt.subplots()
57
  ax.plot(audio_data)
58
  ax.set_title("Waveform")
@@ -65,7 +84,7 @@ def classify_audio(file_path):
65
  except Exception as e:
66
  return f"Error processing audio: {str(e)}", {}, None
67
 
68
- # Gradio interface (HF-compatible)
69
  interface = gr.Interface(
70
  fn=classify_audio,
71
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),
 
5
  import gradio as gr
6
  import soundfile as sf
7
  from scipy.signal import resample
8
+ import tempfile
9
+ import os
10
 
11
  # Load YAMNet model from TensorFlow Hub
12
  yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
 
22
 
23
  class_names = load_class_map()
24
 
25
+ # Main classification function
26
+ def classify_audio(audio_input):
27
  try:
28
+ # Case 1: Filepath from Gradio UI
29
+ if isinstance(audio_input, str):
30
+ file_path = audio_input
31
+
32
+ # Case 2: Binary upload (n8n POST) without .name attribute
33
+ elif hasattr(audio_input, "read"):
34
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
35
+ tmp.write(audio_input.read())
36
+ file_path = tmp.name
37
+ else:
38
+ raise ValueError("Unsupported input format")
39
+
40
+ # Load audio file
41
  audio_data, sample_rate = sf.read(file_path)
42
 
43
+ # Cleanup if temp file was created
44
+ if 'tmp' in locals():
45
+ os.unlink(tmp.name)
46
+
47
  # Convert stereo to mono
48
  if len(audio_data.shape) > 1:
49
  audio_data = np.mean(audio_data, axis=1)
 
59
  audio_data = resample(audio_data, new_length)
60
  sample_rate = target_rate
61
 
62
+ # Tensor for model
63
  waveform = tf.convert_to_tensor(audio_data, dtype=tf.float32)
64
 
65
+ # Run YAMNet model
66
  scores, embeddings, spectrogram = yamnet_model(waveform)
67
  mean_scores = tf.reduce_mean(scores, axis=0).numpy()
68
  top_5 = np.argsort(mean_scores)[::-1][:5]
69
 
70
+ # Output results
71
  top_prediction = class_names[top_5[0]]
72
  top_scores = {class_names[i]: float(mean_scores[i]) for i in top_5}
73
 
74
+ # Plot waveform
75
  fig, ax = plt.subplots()
76
  ax.plot(audio_data)
77
  ax.set_title("Waveform")
 
84
  except Exception as e:
85
  return f"Error processing audio: {str(e)}", {}, None
86
 
87
+ # Gradio Interface
88
  interface = gr.Interface(
89
  fn=classify_audio,
90
  inputs=gr.Audio(type="filepath", label="Upload .wav or .mp3 audio file"),