# WhisperWithJPDiarization / transcribe_japanese_with_diarization.py
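"""Transcribe a Japanese audio file with speaker diarization.

Runs a fine-tuned pyannote.audio speaker-diarization pipeline over the input
audio, transcribes each diarized segment with OpenAI Whisper (large-v2), and
writes a speaker-labelled, timestamped script to timed_script.txt.

Requires the HF_TOKEN_NOT_LOGIN environment variable (a HuggingFace access
token) and the fine-tuned segmentation checkpoint
finetuned_segmentation_model_improved.ckpt in the working directory.
"""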
import os
import sys

import torch
import whisper
from pyannote.audio import Pipeline, Audio, Model
from pyannote.audio.pipelines import SpeakerDiarization


def should_skip_line(s: str) -> bool:
    """Return True for transcript lines that should be dropped from the output."""
    # Drop lines whose transcription is empty after the "SPEAKER_XX:" prefix.
    parts = s.split(':', 1)
    if len(parts) > 1 and parts[1].strip() == '':
        return True
    # Drop phrases Whisper tends to hallucinate on silent or near-silent segments
    # ("Thank you for watching").
    phrases_to_skip = ["ご視聴ありがとうございました", "by H."]
    for phrase in phrases_to_skip:
        if phrase in s:
            return True
    return False


def main(audio_file):
    # Get the HuggingFace token from the environment.
    HF_TOKEN = os.environ.get("HF_TOKEN_NOT_LOGIN")
    if not HF_TOKEN:
        print("Error: HF_TOKEN_NOT_LOGIN environment variable is not set.")
        sys.exit(1)

    # Load the pretrained pyannote.audio speaker-diarization pipeline.
    pretrained_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
    )

    # Build a diarization pipeline around the fine-tuned segmentation model,
    # reusing the pretrained pipeline's clustering step (`klustering` is the
    # attribute name used by pyannote's SpeakerDiarization pipeline).
    finetuned_model_path = "finetuned_segmentation_model_improved.ckpt"
    finetuned_model = Model.from_pretrained(finetuned_model_path)

    # Thresholds found by hyperparameter optimization of the fine-tuned pipeline.
    best_segmentation_threshold = 0.6455219392347773
    best_clustering_threshold = 0.6425210602903073

    finetuned_pipeline = SpeakerDiarization(
        segmentation=finetuned_model,
        embedding="pyannote/wespeaker-voxceleb-resnet34-LM",
        clustering=pretrained_pipeline.klustering,
    )
    finetuned_pipeline.instantiate({
        "segmentation": {
            "threshold": best_segmentation_threshold,
            "min_duration_off": 0.0,
        },
        "clustering": {
            "method": "centroid",
            "min_cluster_size": 15,
            "threshold": best_clustering_threshold,
        },
    })
    finetuned_pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # Apply speaker diarization.
    who_speaks_when = finetuned_pipeline(audio_file)

    # Load the OpenAI Whisper speech-recognition model.
    print("Loading whisper model...")
    model = whisper.load_model("large-v2", device="cuda" if torch.cuda.is_available() else "cpu")
    print("Whisper model loaded.")

    # Transcribe each diarized segment.
    print("Importing Audio!")
    audio = Audio(sample_rate=16000, mono=True)
    transcribed_lines = []
    for segment, _, speaker in who_speaks_when.itertracks(yield_label=True):
        # Crop the segment's waveform and feed it to Whisper as a mono 16 kHz array.
        waveform, sample_rate = audio.crop(audio_file, segment)
        text = model.transcribe(waveform.squeeze().numpy(), language="Japanese")["text"]
        timed_line = f"{segment.start:06.1f}s {segment.end:06.1f}s {speaker}: {text}"
        print(timed_line)
        transcribed_lines.append(timed_line)

    # Write the transcription to file, skipping empty and hallucinated lines.
    output_file = 'timed_script.txt'
    with open(output_file, 'w', encoding='UTF-8') as f:
        for line in transcribed_lines:
            if not should_skip_line(line):
                f.write(line + '\n')
    print(f"Transcription completed. Output saved to {output_file}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python transcribe_japanese_with_diarization.py <audio_file>")
        sys.exit(1)
    audio_file = sys.argv[1]
    if not os.path.exists(audio_file):
        print(f"Error: The file '{audio_file}' does not exist.")
        sys.exit(1)
    main(audio_file)
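
# Example invocation (token value and audio path below are placeholders):
#   HF_TOKEN_NOT_LOGIN=hf_xxxxxxxx python transcribe_japanese_with_diarization.py recording.wav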