ShReYas6969 committed
Commit 249ad34 · verified · 1 Parent(s): e6501b4

Update speech_model.py

Files changed (1)
  1. speech_model.py +92 -91
speech_model.py CHANGED
@@ -1,91 +1,92 @@
- import whisper
- from transformers import pipeline
- import numpy as np
- import os
-
-
- class SpeechEmotionAnalyzer:
-     """
-     A class to transcribe audio and classify the emotion from the speech.
-     """
-
-     def __init__(self, whisper_model="tiny", emotion_model="prithivMLmods/Speech-Emotion-Classification"):
-         """
-         Initializes the SpeechEmotionAnalyzer.
-
-         Args:
-             whisper_model (str): The name of the Whisper model to use for transcription.
-             emotion_model (str): The Hugging Face model to use for speech emotion classification.
-         """
-         # Load the Whisper model for speech-to-text
-         print("Loading Whisper model...")
-         self.whisper_model = whisper.load_model(whisper_model)
-
-         # Load the pipeline for audio classification
-         print("Loading speech emotion classification model...")
-         self.emotion_classifier = pipeline(
-             "audio-classification",
-             model=emotion_model
-         )
-         print("SpeechEmotionAnalyzer initialized successfully.")
-
-     def process_audio(self, audio_data: np.ndarray, sample_rate: int) -> tuple[str, str | None]:
-         """
-         Transcribes audio and classifies its emotion.
-
-         Args:
-             audio_data (np.ndarray): The raw audio data as a NumPy array.
-             sample_rate (int): The sample rate of the audio data.
-
-         Returns:
-             tuple[str, str | None]: A tuple containing:
-                 - The transcribed text.
-                 - The detected emotion label (e.g., 'SAD', 'HAPPY') or None if classification fails.
-         """
-         # Ensure audio is in the correct format (float32) for Whisper
-         if audio_data.dtype != np.float32:
-             audio_data = audio_data.astype(np.float32) / 32767.0
-
-         # 1. Transcribe audio to text using Whisper
-         print("Transcribing audio...")
-         transcription_result = self.whisper_model.transcribe(audio_data)
-         text = transcription_result.get("text", "").strip()
-
-         # 2. Classify emotion from the audio
-         print("Classifying speech emotion...")
-         try:
-             # The pipeline expects a dictionary with 'raw' audio data and 'sampling_rate'
-             audio_input = {"raw": audio_data, "sampling_rate": sample_rate}
-             emotion_results = self.emotion_classifier(audio_input, top_k=1)
-
-             # The result is a list of lists, get the top result
-             if emotion_results and emotion_results[0]:
-                 emotion = emotion_results[0][0]['label']
-             else:
-                 emotion = None
-         except Exception as e:
-             print(f"Could not classify speech emotion: {e}")
-             emotion = None
-
-         return text, emotion
-
-
- if __name__ == '__main__':
-     # Example usage: This part is harder to test standalone without an audio file.
-     # The main.py script will handle live microphone input.
-     # You can uncomment and modify the following to test with a local audio file.
-
-     # from scipy.io.wavfile import read
-     # try:
-     #     analyzer = SpeechEmotionAnalyzer()
-     #     # Make sure you have a 'test_audio.wav' file in the same directory.
-     #     sample_rate, audio_data = read("test_audio.wav")
-     #     text, emotion = analyzer.process_audio(audio_data, sample_rate)
-     #     print("--- Analysis Result ---")
-     #     print(f"Transcription: {text}")
-     #     print(f"Vocal Emotion: {emotion}")
-     # except FileNotFoundError:
-     #     print("Could not find 'test_audio.wav'. Skipping standalone test.")
-     # except Exception as e:
-     #     print(f"An error occurred during standalone test: {e}")
-     pass
+ # speech_model.py
+
+ import whisper
+ from transformers import pipeline
+ import numpy as np
+ import os
+ from typing import Union, Tuple
+
+ class SpeechEmotionAnalyzer:
+     """
+     A class to transcribe audio and classify the emotion from the speech.
+     """
+     def __init__(self, whisper_model="tiny", emotion_model="prithivMLmods/Speech-Emotion-Classification"):
+         """
+         Initializes the SpeechEmotionAnalyzer.
+
+         Args:
+             whisper_model (str): The name of the Whisper model to use for transcription.
+             emotion_model (str): The Hugging Face model to use for speech emotion classification.
+         """
+         # Load the Whisper model for speech-to-text
+         print("Loading Whisper model...")
+         self.whisper_model = whisper.load_model(whisper_model)
+
+         # Load the pipeline for audio classification
+         print("Loading speech emotion classification model...")
+         self.emotion_classifier = pipeline(
+             "audio-classification",
+             model=emotion_model
+         )
+         print("SpeechEmotionAnalyzer initialized successfully.")
+
+     def process_audio(self, audio_data: np.ndarray, sample_rate: int) -> Tuple[str, Union[str, None]]:
+         """
+         Transcribes audio and classifies its emotion.
+
+         Args:
+             audio_data (np.ndarray): The raw audio data as a NumPy array.
+             sample_rate (int): The sample rate of the audio data.
+
+         Returns:
+             A tuple containing:
+                 - The transcribed text.
+                 - The detected emotion label (e.g., 'SAD', 'HAPPY') or None if classification fails.
+         """
+         # Ensure audio is in the correct format (float32) for Whisper
+         if audio_data.dtype != np.float32:
+             audio_data = audio_data.astype(np.float32) / 32767.0
+
+         # 1. Transcribe audio to text using Whisper
+         print("Transcribing audio...")
+         transcription_result = self.whisper_model.transcribe(audio_data)
+         text = transcription_result.get("text", "").strip()
+
+         # 2. Classify emotion from the audio
+         print("Classifying speech emotion...")
+         try:
+             # The pipeline expects a dictionary with 'raw' audio data and 'sampling_rate'
+             audio_input = {"raw": audio_data, "sampling_rate": sample_rate}
+             emotion_results = self.emotion_classifier(audio_input, top_k=1)
+
+             # The result is a list of lists, get the top result
+             if emotion_results and emotion_results[0]:
+                 emotion = emotion_results[0][0]['label']
+             else:
+                 emotion = None
+         except Exception as e:
+             print(f"Could not classify speech emotion: {e}")
+             emotion = None
+
+         return text, emotion
+
+
+ if __name__ == '__main__':
+     # Example usage: This part is harder to test standalone without an audio file.
+     # The main.py script will handle live microphone input.
+     # You can uncomment and modify the following to test with a local audio file.
+
+     # from scipy.io.wavfile import read
+     # try:
+     #     analyzer = SpeechEmotionAnalyzer()
+     #     # Make sure you have a 'test_audio.wav' file in the same directory.
+     #     sample_rate, audio_data = read("test_audio.wav")
+     #     text, emotion = analyzer.process_audio(audio_data, sample_rate)
+     #     print("--- Analysis Result ---")
+     #     print(f"Transcription: {text}")
+     #     print(f"Vocal Emotion: {emotion}")
+     # except FileNotFoundError:
+     #     print("Could not find 'test_audio.wav'. Skipping standalone test.")
+     # except Exception as e:
+     #     print(f"An error occurred during standalone test: {e}")
+     pass
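
Note for reviewers: past the new header comment, the substantive change is the return annotation on process_audio. A minimal sketch of why the typing-module spelling is the safer choice (the function names below are placeholders, not part of the commit): `tuple[str, str | None]` parses only on Python 3.10+, since PEP 585 built-in generics need 3.9 and PEP 604 `X | Y` unions need 3.10, while the typing-module forms run on older interpreters as well.

    from typing import Optional, Tuple, Union

    # Portable spelling introduced by this commit:
    def process_audio_portable(audio) -> Tuple[str, Union[str, None]]:
        ...

    # Equivalent, slightly more idiomatic spelling of the same type:
    def process_audio_optional(audio) -> Tuple[str, Optional[str]]:
        ...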
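To exercise the new file without the missing test_audio.wav, here is a hypothetical smoke test (not part of the commit). It assumes both models can be downloaded and uses 16 kHz input, since Whisper's transcribe() expects 16 kHz audio when handed a raw array and process_audio never resamples.

    import numpy as np
    from speech_model import SpeechEmotionAnalyzer

    # One second of int16 silence at 16 kHz. This exercises the
    # int16 -> float32 branch in process_audio; the division by
    # 32767.0 assumes int16 samples, which is what
    # scipy.io.wavfile.read returns for 16-bit WAV files.
    sample_rate = 16_000
    audio_data = np.zeros(sample_rate, dtype=np.int16)

    analyzer = SpeechEmotionAnalyzer()  # downloads both models on first run
    text, emotion = analyzer.process_audio(audio_data, sample_rate)
    print(f"Transcription: {text!r}, Vocal Emotion: {emotion}")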
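One caveat worth flagging: the comment "The result is a list of lists" only holds for batched input. In recent transformers releases, a single input to an audio-classification pipeline returns a flat list of {'label', 'score'} dicts, in which case emotion_results[0][0] would raise and the except branch would silently map every classification to None. A hedged sketch of a shape-tolerant parser (the helper name is hypothetical, not part of the commit):

    def top_label(emotion_results):
        """Return the top emotion label from either pipeline output shape."""
        if not emotion_results:
            return None
        first = emotion_results[0]
        if isinstance(first, dict):            # single input: [{'label': ..., 'score': ...}]
            return first.get("label")
        if isinstance(first, list) and first:  # batched input: [[{'label': ...}, ...], ...]
            return first[0].get("label")
        return None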