# multimodal_emotion_recognition/src/audio_processor.py
import time
import torch
import librosa
import numpy as np
import gradio as gr
from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor
SAMPLING_RATE = 16_000
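
# Note (added): the wav2vec 2.0 processor below expects 16 kHz mono audio; chunk
# timestamps are converted to sample offsets at this rate.
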
class AudioProcessor:
    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto",
        }
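        # Behaviour classes predicted by the optional behaviour model, with their Italian display names.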
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustrazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model
        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()
        self.behaviour_confidence = 0.6
        self.chart_generator = None

    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()

    def _prepare_transcribed_text(self, chunks):
        formatted_timestamps = []
        predictions = []
        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime('%H:%M:%S', time.gmtime(start))
            formatted_end = time.strftime('%H:%M:%S', time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
            predictions.append(f"**[{chunk[2]}]**")
        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            [
                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}" for i in range(len(transcribed_texts))
            ]
        )
        print(f"Transcribed text:\n{transcribed_text}")
        return transcribed_text

    def __call__(self, audio_path: str):
        """
        Predicts emotion and behaviour labels for each speech chunk of an audio file.

        Args:
            audio_path (str): Path to the audio file to be processed.

        Returns:
            tuple: A one-element tuple containing the behaviour Gantt plot.
        """
        print("Segmenting audio...")
        out = self.segmentation_model(
            inputs=audio_path,
            return_timestamps=True,
        )
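        # The segmentation pipeline returns per-chunk transcripts with (start, end) timestamps in seconds.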
        # Load the raw waveform so chunk timestamps can be mapped to sample offsets.
        input_frames, _ = librosa.load(audio_path, sr=SAMPLING_RATE)
        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []
        print("Analyzing chunks...")
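        # Each chunk is classified twice: emotion from audio plus transcript, behaviour from audio alone.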
        for chunk in out["chunks"]:
            # Trim the audio to the chunk boundaries reported by the segmentation model.
            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
            audio = input_frames[start:end]
            inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)
            print(f"Inputs: {inputs}")
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs['input_features'] = inputs['input_features'].to(self.device)
            inputs['input_ids'] = inputs['input_ids'].to(self.device)
            inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device)
            print("Predicting emotion for chunk...")
            logits = self.emotion_model(**inputs).logits
            logits = logits.detach().cpu()
            softmax = torch.nn.Softmax(dim=1)
            probabilities = softmax(logits).squeeze(0)
            prediction = probabilities.argmax().item()
            # The id2label mapping lives on the emotion model's config; the processor has no `config` attribute.
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]
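            # Store chunk boundaries, translated label, transcript, and rounded confidence for later reporting.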
            emotion_chunks.append(
                (
                    start,
                    end,
                    label_translation,
                    chunk["text"],
                    np.round(probabilities[prediction].item(), 2),
                )
            )
            timestamps.append((start, end))
            predicted_labels.append(label_translation)
            all_probabilities.append(probabilities[prediction].item())
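            # Re-encode the audio alone (no transcript) as input for the behaviour model.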
            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs = inputs.input_features.to(self.device)
            print("Predicting behaviour for chunk...")
            logits = self.behaviour_model(inputs).logits
            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
            behaviour_chunks.append(
                (
                    start,
                    end,
                    chunk["text"],
                    # Index 2 corresponds to the "dysregulated" class in self.behaviour_labels.
                    np.round(probabilities[2].item(), 2),
                    label_translation,
                )
            )
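        # Summarise the per-chunk behaviour scores in a Gantt-style chart.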
        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)
        return (
            behaviour_gantt,
            # transcribed_text,
        )
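

# Usage sketch (added, not part of the original module): the wiring below is an
# assumption. It presumes the segmentation model is a Hugging Face ASR pipeline that
# returns timestamped chunks, and leaves the emotion and behaviour checkpoints as
# placeholders, since their concrete classes are defined elsewhere in this repo.
#
# from transformers import pipeline
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# segmentation_model = pipeline(
#     "automatic-speech-recognition",
#     model="openai/whisper-small",  # placeholder ASR checkpoint
#     return_timestamps=True,
# )
# emotion_model = ...      # custom audio+text emotion classifier expected by AudioProcessor
# behaviour_model = ...    # audio-only behaviour classifier with 3 output classes
#
# audio_processor = AudioProcessor(emotion_model, segmentation_model, device, behaviour_model)
# (behaviour_gantt,) = audio_processor("example.wav")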