from pydub import AudioSegment
import os
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer, pipeline
import torchaudio
import torch
import re
from peft import PeftModel, PeftConfig
import spaces  # ZeroGPU support on Hugging Face Spaces

device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

### Configuration
MODEL_NAME_V2 = "./whisper-large-v3-catalan"
MODEL_NAME_V1 = "projecte-aina/whisper-large-v3-tiny-caesar"
CHUNK_LENGTH = 30
BATCH_SIZE = 1

pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME_V1,
    chunk_length_s=CHUNK_LENGTH,
    device=device,
    token=os.getenv("HF_TOKEN"),
)

# Load the base Whisper model and attach the fine-tuned PEFT adapter (v2).
peft_config = PeftConfig.from_pretrained(MODEL_NAME_V2)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
)

task = "transcribe"
model = PeftModel.from_pretrained(model, MODEL_NAME_V2)
model.config.use_cache = True

tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(task=task)

asr_pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=CHUNK_LENGTH,
)

def asr(audio_path, task):
    asr_result = asr_pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
    base_model = asr_pipe.model.base_model if hasattr(asr_pipe.model, "base_model") else asr_pipe.model
    return asr_result

def post_process_transcription(transcription, max_repeats=2):
    # Tokenize into words (keeping trailing punctuation), then drop character-level
    # repetition runs and tokens repeated more than max_repeats times in a row.
    tokens = re.findall(r"\b\w+'?\w*\b[.,!?]?", transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Strip runs where a 1-3 character sequence repeats 3+ times (ASR stutter).
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", "", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r"\s+", " ", cleaned_transcription).strip()
    return cleaned_transcription

def format_audio(audio_path):
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:  # stereo -> mono
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    input_audio = resampler(input_audio)
    input_audio = input_audio.squeeze().numpy()
    return input_audio

def split_stereo_channels(audio_path):
    audio = AudioSegment.from_wav(audio_path)
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")
    channels[1].export("temp_mono_speaker2.wav", format="wav")

def transcribe_pipeline(audio, task):
    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
    return text

def generate(audio_path, use_v2):
    task = "transcribe"    # used by both branches
    temp_mono_path = None  # set only if a temporary mono file is written
    if use_v2:
        # Stereo input: one speaker per channel, transcribed separately.
        split_stereo_channels(audio_path)
        audio_id = os.path.splitext(os.path.basename(audio_path))[0]
        left_channel_path = "temp_mono_speaker2.wav"
        right_channel_path = "temp_mono_speaker1.wav"
        left_audio = format_audio(left_channel_path)
        right_audio = format_audio(right_channel_path)
        left_result = asr(left_audio, task)
        right_result = asr(right_audio, task)
        left_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 1", post_process_transcription(seg["text"]))
            for seg in left_result["chunks"]
        ]
        right_segs = [
            (seg["timestamp"][0], seg["timestamp"][1], "Speaker 2", post_process_transcription(seg["text"]))
            for seg in right_result["chunks"]
        ]
        merged_transcript = sorted(left_segs + right_segs, key=lambda x: x[0])
        merged_text = " ".join([seg[3] for seg in merged_transcript])
        output = ""
        for start, end, speaker, text in merged_transcript:
            output += f"[{start:.2f}s - {end:.2f}s] {speaker}: {text}\n"
        # Segments were already cleaned individually; keep the diarized formatting.
        clean_output = output
    else:
        audio = AudioSegment.from_wav(audio_path)
        if audio.channels != 1:  # stereo -> mono
            audio = audio.set_channels(1)
            temp_mono_path = "temp_mono.wav"
            audio.export(temp_mono_path, format="wav")
            audio_path = temp_mono_path
        output = transcribe_pipeline(format_audio(audio_path), task)
        clean_output = post_process_transcription(output, max_repeats=1)
    # Clean up temporary files.
    if temp_mono_path and os.path.exists(temp_mono_path):
        os.remove(temp_mono_path)
    for temp_file in ["temp_mono_speaker1.wav", "temp_mono_speaker2.wav"]:
        if os.path.exists(temp_file):
            os.remove(temp_file)
    return clean_output
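
# ---------------------------------------------------------------------------
# Hypothetical UI wiring (not shown in the snippet above): a minimal sketch of
# how generate() could be exposed as the Space's Gradio interface. Component
# choices and labels are assumptions, not the original app code. On a ZeroGPU
# Space, the inference entry point would also normally carry the @spaces.GPU
# decorator.
import gradio as gr

demo = gr.Interface(
    fn=generate,  # defined above; expects (audio_path, use_v2)
    inputs=[
        gr.Audio(type="filepath", label="Audio (WAV)"),
        gr.Checkbox(label="Use v2 (stereo file, one speaker per channel)", value=False),
    ],
    outputs=gr.Textbox(label="Transcription"),
)

if __name__ == "__main__":
    demo.launch()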