# Page-scrape residue ("Spaces: Sleeping") -- not part of the script.
import csv
import os
import tempfile
import time

import torch
from pydub import AudioSegment
from tqdm import tqdm
from transformers import pipeline
# ======== Configuration ========
# Name of the Whisper model to use
model_name = "nectec/Pathumma-whisper-th-large-v3"
# Folder containing the source audio files
input_folder = "/kaggle/input/audio-understanding/speechs/speechs/test"
# CSV file the results are written to
output_csv = "asr.csv"
# ===============================
# Taken before the model is loaded, so anything measured against
# start_time includes model loading.
start_time = time.perf_counter()

# Select GPU / CPU: device index 0 when CUDA is available, otherwise -1,
# with bfloat16 weights on GPU and float32 on CPU.
use_cuda = torch.cuda.is_available()
device = 0 if use_cuda else -1
torch_dtype = torch.bfloat16 if use_cuda else torch.float32

# Load the Pathumma model as a speech-recognition pipeline.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model_name,
    torch_dtype=torch_dtype,
    device=device,
)

# Set the language and task used for decoding.
lang = "th"
task = "transcribe"
# NOTE(review): newer transformers releases appear to prefer passing
# language/task through generate_kwargs instead of forced_decoder_ids --
# confirm against the installed version before changing this.
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
    language=lang, task=task
)
# One {"id", "transcription"} row per input file, in processing order.
results = []

# Length of each slice handed to the model, in milliseconds (27 seconds).
chunk_length_ms = 27000

# Every .wav file in the input folder, in sorted name order.
wav_files = sorted(f for f in os.listdir(input_folder) if f.endswith(".wav"))

# Chunks are staged in a private temporary directory instead of the working
# directory: the path cannot collide with another run, and the directory is
# removed when the block exits, even if the run is interrupted.
with tempfile.TemporaryDirectory(prefix="asr_chunks_") as tmp_dir:
    # A single staging file is reused; each export overwrites the last chunk.
    chunk_path = os.path.join(tmp_dir, "chunk.wav")

    for filename in tqdm(wav_files, desc=":open_file_folder: Processing files"):
        full_path = os.path.join(input_folder, filename)
        try:
            audio = AudioSegment.from_file(full_path)
        except Exception as e:
            # Best effort: record the failure for this file and keep going.
            print(f"\n:x: Error loading {filename}: {e}")
            results.append(
                {"id": filename, "transcription": "[ERROR: Cannot load file]"}
            )
            continue

        # Ceiling division so a trailing partial chunk is still processed.
        num_chunks = (len(audio) + chunk_length_ms - 1) // chunk_length_ms
        pieces = []
        for i in tqdm(
            range(num_chunks), desc=f":loud_sound: Chunks for {filename}", leave=False
        ):
            start = i * chunk_length_ms
            chunk = audio[start : start + chunk_length_ms]
            try:
                # Export is inside the try so a failed write is reported for
                # this chunk only instead of aborting the whole run.
                chunk.export(chunk_path, format="wav")
                output = pipe(chunk_path)
                pieces.append(output["text"].strip())
            except Exception as e:
                print(f"\n:x: Error on chunk {i} of {filename}: {e}")
                pieces.append("[ERROR]")

        full_transcription = " ".join(pieces).strip()
        results.append({"id": filename, "transcription": full_transcription})
# Write one row per entry in results to the output CSV (UTF-8, header first).
# newline="" lets the csv module control line endings itself.
with open(output_csv, mode="w", newline="", encoding="utf-8") as file:
    writer = csv.DictWriter(file, fieldnames=["id", "transcription"])
    writer.writeheader()
    writer.writerows(results)

# Report the time elapsed since start_time was taken.
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"\n:white_check_mark: All done! Time taken: {elapsed_time:.2f} seconds")
print(f":page_facing_up: Results saved to {output_csv}")