jcho02 committed
Commit a9b6797 · verified · 1 Parent(s): 73b065a

Update app.py

Files changed (1):
  1. app.py +34 -20
app.py CHANGED
@@ -67,35 +67,49 @@ def predict(audio_data, sampling_rate, config):
     input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
 
     model = SpeechClassifier(config).to(device)
+    # Here we load the model from Hugging Face Hub
     model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
-    model.eval()
 
+    model.eval()
     with torch.no_grad():
         logits = model(input_features, decoder_input_ids)
     predicted_ids = int(torch.argmax(logits, dim=-1))
     return predicted_ids
 
-# Unified Gradio interface function
-def gradio_interface(audio_input):
-    if isinstance(audio_input, tuple):
-        # If the input is a tuple, it's from the microphone
-        audio_data, sample_rate = audio_input
-    else:
-        # Otherwise, it's an uploaded file
-        with open(audio_input, "rb") as f:
-            audio_data = np.frombuffer(f.read(), np.int16)
-        sample_rate = 16000  # Assume 16kHz sample rate for uploaded files
-
-    prediction = predict(audio_data, sample_rate, config)
+# Gradio Interface functions
+def gradio_file_interface(uploaded_file):
+    # Assuming the uploaded_file is a filepath (str)
+    with open(uploaded_file, "rb") as f:
+        audio_data = np.frombuffer(f.read(), np.int16)
+    prediction = predict(audio_data, 16000, config)  # Assume 16kHz sample rate
+    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
+    return label
+
+def gradio_mic_interface(mic_input):
+    # mic_input is a dictionary with 'data' and 'sample_rate' keys
+    prediction = predict(mic_input['data'], mic_input['sample_rate'], config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
-# Create Gradio interface
-demo = gr.Interface(
-    fn=gradio_interface,
-    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
-    outputs=gr.Textbox(label="Prediction")
-)
+# Initialize Blocks
+demo = gr.Blocks()
+
+# Define the interfaces inside the Blocks context
+with demo:
+    #mic_transcribe = gr.Interface(
+    #    fn=gradio_mic_interface,
+    #    inputs=gr.Audio(type="numpy"),  # Use numpy for real-time audio like microphone
+    #    outputs=gr.Textbox(label="Prediction")
+    #)
+
+    file_transcribe = gr.Interface(
+        fn=gradio_file_interface,
+        inputs=gr.Audio(type="filepath"),  # Use filepath for uploaded audio files
+        outputs=gr.Textbox(label="Prediction")
+    )
+
+    # Combine interfaces into a tabbed interface
+    #gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
 
-# Launch the demo
+# Launch the demo with debugging enabled
 demo.launch(debug=True)
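
For context, a minimal self-contained sketch of the Blocks-plus-filepath pattern the new code adopts. The `config` value, the `predict` stub, and its 0/1 logic below are illustrative stand-ins for the real definitions elsewhere in app.py, not the committed code:

# Minimal sketch of the gr.Blocks + gr.Audio(type="filepath") pattern.
# Assumptions: `config` and `predict` are placeholders standing in for
# the real SpeechClassifier pipeline defined earlier in app.py.
import numpy as np
import gradio as gr

config = {"encoder": "openai/whisper-base"}  # placeholder, not the real config

def predict(audio_data, sampling_rate, config):
    # Stand-in for the model forward pass; returns 0 or 1,
    # mirroring the argmax over logits in the real predict().
    return int(audio_data.size % 2)

def gradio_file_interface(uploaded_file):
    # gr.Audio(type="filepath") passes the fn a path string, so the
    # file's bytes are read and interpreted as 16-bit PCM samples.
    with open(uploaded_file, "rb") as f:
        audio_data = np.frombuffer(f.read(), np.int16)
    prediction = predict(audio_data, 16000, config)  # assume 16 kHz
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

demo = gr.Blocks()
with demo:
    gr.Interface(
        fn=gradio_file_interface,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Textbox(label="Prediction"),
    )

demo.launch(debug=True)

As in the commit, the sample rate of uploaded files is assumed to be 16 kHz, since a bare file path carries no sample-rate metadata the way the microphone payload does.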