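# Fine-tune openai/whisper-small for Bengali speech recognition on Common Voice 11.0,
# Google FLEURS, OpenSLR SLR53 and a local CRBLP corpus, with waveform augmentation,
# Bangla text normalisation and WER/CER evaluation.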
import os |
|
import torch |
|
|
|
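# base_dir is the parent of the current working directory; the Hugging Face model and
# dataset caches are redirected there so repeated runs reuse downloaded files.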
abs_path = os.path.abspath('.') |
|
|
|
base_dir = os.path.dirname(abs_path) |
|
|
|
os.environ['TRANSFORMERS_CACHE'] = os.path.join(base_dir, 'models_cache') |
|
os.environ['TRANSFORMERS_OFFLINE'] = '0' |
|
os.environ['HF_DATASETS_CACHE'] = os.path.join(base_dir, 'datasets_cache') |
|
os.environ['HF_DATASETS_OFFLINE'] = '0' |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
print(f"\n\n Device to be used: {device} \n\n") |
|
|
|
|
|
|
|
|
|
model_name = "openai/whisper-small" |
|
|
|
|
|
language = "Bengali" |
|
task = "transcribe" |
|
print(f"\n\n Loading {model_name} for {language} ({task})... this might take a while... \n\n")
|
|
|
|
|
|
|
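# Training hyperparameters. The effective training batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 16 = 64.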
output_dir = "./" |
|
overwrite_output_dir = True |
|
max_steps = 40000 |
|
|
|
per_device_train_batch_size = 4 |
|
|
|
per_device_eval_batch_size = 32 |
|
|
|
gradient_accumulation_steps = 16 |
|
|
|
dataloader_num_workers = 0 |
|
gradient_checkpointing = False |
|
evaluation_strategy = "steps"
|
|
|
eval_steps = 1000 |
|
save_strategy = "steps" |
|
save_steps = 1000 |
|
|
|
save_total_limit = 5 |
|
learning_rate = 1e-5 |
|
lr_scheduler_type = "cosine" |
|
warmup_steps = 8000 |
|
|
|
logging_steps = 25 |
|
|
|
|
|
weight_decay = 0 |
|
dropout = 0.1 |
|
load_best_model_at_end = True |
|
metric_for_best_model = "wer" |
|
greater_is_better = False |
|
bf16 = True |
|
|
|
tf32 = True |
|
|
|
generation_max_length = 448 |
|
report_to = ["tensorboard"] |
|
predict_with_generate = True |
|
push_to_hub = True |
|
|
|
freeze_feature_encoder = False |
|
early_stopping_patience = 10 |
|
apply_spec_augment = True |
|
torch_compile = False |
|
optim = "adamw_hf"
|
|
|
|
|
|
|
print("\n\n Loading Datasets...this might take a while..\n\n") |
|
|
|
from datasets import load_dataset, DatasetDict, Features, Value, Audio |
|
|
|
common_voice = DatasetDict() |
|
google_fleurs = DatasetDict() |
|
openslr = DatasetDict() |
|
|
|
my_dataset = DatasetDict() |
|
|
|
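# Train data: Common Voice 11.0 (bn) train+validation, Google FLEURS (bn_in) train+validation,
# all of OpenSLR SLR53, and a local CRBLP CSV corpus loaded below. Test data comes from the
# Common Voice and FLEURS test splits only.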
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "bn", split="train+validation", cache_dir=os.path.join(base_dir, 'datasets_cache')) |
|
google_fleurs["train"] = load_dataset("google/fleurs", "bn_in", split="train+validation", cache_dir=os.path.join(base_dir, 'datasets_cache')) |
|
openslr = load_dataset("openslr", "SLR53", cache_dir=os.path.join(base_dir, 'datasets_cache')) |
|
|
|
|
|
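# Schema for the local CRBLP CSV: a transcript, the audio file path, and the audio itself
# decoded at 16 kHz.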
features = Features( |
|
{ |
|
"text": Value("string"), |
|
'path': Value('string'), |
|
"audio": Audio(sampling_rate=16000) |
|
} |
|
) |
|
|
|
crblp = load_dataset( |
|
'csv', |
|
data_files='D:/Govt_Speech_Demo/crblp_speech_corpus/crblp_train.csv', |
|
split='train', |
|
cache_dir=os.path.join(base_dir, 'datasets_cache'), |
|
features=features |
|
) |
|
|
|
|
|
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "bn", split="test", cache_dir=os.path.join(base_dir, 'datasets_cache')) |
|
google_fleurs["test"] = load_dataset("google/fleurs", "bn_in", split="test", cache_dir=os.path.join(base_dir, 'datasets_cache')) |
|
|
|
|
|
|
|
print("\n\n Datasets Loaded \n\n") |
|
print(common_voice) |
|
print(google_fleurs) |
|
print(openslr) |
|
print(crblp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n BEFORE Filtering by Upvotes (Common Voice): \n") |
|
print(common_voice["train"]) |
|
|
|
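# Keep only Common Voice training clips whose net community vote (up_votes - down_votes)
# is non-negative.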
common_voice["train"] = common_voice["train"].filter(lambda x: (x["up_votes"] - x["down_votes"]) >= 0, num_proc=None) |
|
print("\n AFTER Filtering by Upvotes (Common Voice): \n") |
|
print(common_voice["train"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n So, the datasets to be used for training are: \n\n")
|
print("\n Common Voice 11.0 - Bangla\n") |
|
print(common_voice) |
|
print("\n Google Fleurs - Bangla \n") |
|
print(google_fleurs) |
|
print("\n OpenSLR-53 - Bangla \n") |
|
print(openslr) |
|
print("\n CRBLP - Bangla \n") |
|
print(crblp) |
|
print("\n") |
|
|
|
|
|
|
|
|
|
from datasets import concatenate_datasets, Audio |
|
|
|
sampling_rate = 16000 |
|
|
|
|
|
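# Resample every corpus to 16 kHz (the rate Whisper expects) and reduce each one to just
# the "audio" and "sentence" columns so they can be concatenated.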
common_voice = common_voice.cast_column("audio", Audio(sampling_rate)) |
|
google_fleurs = google_fleurs.cast_column("audio", Audio(sampling_rate)) |
|
openslr = openslr.cast_column("audio", Audio(sampling_rate)) |
|
crblp = crblp.cast_column("audio", Audio(sampling_rate)) |
|
|
|
|
|
common_voice = common_voice.remove_columns( |
|
set(common_voice['test'].features.keys()) - {"audio", "sentence"} |
|
) |
|
|
|
google_fleurs = google_fleurs.rename_column("raw_transcription", "sentence") |
|
google_fleurs = google_fleurs.remove_columns( |
|
set(google_fleurs['test'].features.keys()) - {"audio", "sentence"} |
|
) |
|
|
|
openslr = openslr.remove_columns( |
|
set(openslr['train'].features.keys()) - {"audio", "sentence"} |
|
) |
|
|
|
crblp = crblp.rename_column("text", "sentence") |
|
crblp = crblp.remove_columns( |
|
set(crblp.features.keys()) - {"audio", "sentence"} |
|
) |
|
|
|
|
|
|
|
print("\n Checking that all audio arrays are float32... \n")
|
print(f'Common Voice Train: {common_voice["train"][0]["audio"]["array"].dtype}') |
|
print(f'Common Voice Test: {common_voice["test"][0]["audio"]["array"].dtype}') |
|
print(f'Google Fleurs Train: {google_fleurs["train"][0]["audio"]["array"].dtype}') |
|
print(f'Google Fleurs Test: {google_fleurs["test"][0]["audio"]["array"].dtype}') |
|
print(f'OpenSlR: {openslr["train"][0]["audio"]["array"].dtype}') |
|
print(f'CRBLP: {crblp[0]["audio"]["array"].dtype}') |
|
print("\n") |
|
|
|
|
|
|
|
|
|
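# Merge all training corpora into a single train split (CRBLP has no test split), and use
# the Common Voice and FLEURS test splits for evaluation.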
my_dataset['train'] = concatenate_datasets([common_voice['train'], google_fleurs['train'], openslr['train'], crblp]) |
|
|
|
|
|
|
|
my_dataset['test'] = concatenate_datasets([common_voice['test'], google_fleurs['test']]) |
|
|
|
|
|
|
|
|
|
my_dataset['train'] = my_dataset['train'].shuffle(seed=10) |
|
|
|
print("\n\n AFTER MERGING, train and validation sets are: ") |
|
print(my_dataset) |
|
print("\n") |
|
|
|
|
|
|
|
print("\n\n Augmenting Datasets...this might take a while..\n\n") |
|
from audiomentations import ( |
|
AddBackgroundNoise, |
|
AddGaussianNoise, |
|
Compose, |
|
Gain, |
|
OneOf, |
|
PitchShift, |
|
PolarityInversion, |
|
TimeStretch, |
|
) |
|
|
|
|
|
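# Waveform-level augmentation: mild time stretch, gain and pitch shift are each applied
# with low probability, while light Gaussian noise is always added (p=1.0).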
augmentation = Compose( |
|
[ |
|
TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=False), |
|
Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1), |
|
PitchShift(min_semitones=-4, max_semitones=4, p=0.2), |
|
AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0), |
|
] |
|
) |
|
|
|
def augment_dataset(batch): |
|
|
|
sample = batch['audio'] |
|
|
|
|
|
augmented_waveform = augmentation(sample["array"], sample_rate=sample["sampling_rate"]) |
|
batch['audio']["array"] = augmented_waveform |
|
return batch |
|
|
|
|
|
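# Augment a copy of the training split and append it to the clean data, doubling the
# amount of training audio.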
augmented_raw_training_dataset = my_dataset["train"].map( |
|
augment_dataset, |
|
num_proc=1, |
|
desc="augment train dataset", |
|
load_from_cache_file=True, |
|
cache_file_name=os.path.join(base_dir, 'datasets_cache', 'augmented_train_cache.arrow') |
|
) |
|
|
|
print("\n COMBINING Augmented Dataset with the Original Dataset... \n")
|
|
|
my_dataset["train"] = concatenate_datasets([my_dataset["train"], augmented_raw_training_dataset]) |
|
my_dataset["train"] = my_dataset["train"].shuffle(seed=42) |
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n AFTER AUGMENTATION, FINAL train and validation sets are: ") |
|
print("\n FINAL DATASET: \n") |
|
print(my_dataset) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperTokenizerFast, WhisperProcessor |
|
|
|
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name) |
|
|
|
|
|
|
|
|
|
|
|
processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task) |
|
|
|
|
|
|
|
print("\n\n Preprocessing Datasets...this might take a while..\n\n") |
|
|
|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
|
from bnunicodenormalizer import Normalizer |
|
import unicodedata |
|
import re |
|
|
|
do_lower_case = False |
|
do_remove_punctuation = False |
|
do_bangla_unicode_normalization = True |
|
|
|
normalizer = BasicTextNormalizer() |
|
bangla_normalizer = Normalizer(allow_english=True) |
|
|
|
|
|
def removeOptionalZW(text): |
|
""" |
|
Removes all optional occurrences of ZWNJ or ZWJ from Bangla text. |
|
""" |
|
|
|
STANDARDIZE_ZW = re.compile(r'(?<=\u09b0)[\u200c\u200d]+(?=\u09cd\u09af)') |
|
|
|
|
|
DELETE_ZW = re.compile(r'(?<!\u09b0)[\u200c\u200d](?!\u09cd\u09af)') |
|
|
|
text = STANDARDIZE_ZW.sub('\u200D', text) |
|
text = DELETE_ZW.sub('', text) |
|
return text |
|
|
|
|
|
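# prepare_dataset: compute log-Mel input features (plus an attention mask when SpecAugment
# is enabled), record input/label lengths for later filtering, normalise the Bangla
# transcript, and tokenise it into label ids.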
def prepare_dataset(batch): |
|
|
|
audio = batch["audio"] |
|
|
|
|
|
inputs = processor.feature_extractor( |
|
audio["array"], |
|
sampling_rate=audio["sampling_rate"], |
|
return_attention_mask=apply_spec_augment, |
|
) |
|
batch["input_features"] = inputs.input_features[0] |
|
|
|
|
|
    # raw waveform length in samples, used later to drop empty or over-long clips
    batch["input_length"] = len(audio["array"])
|
|
|
|
|
if apply_spec_augment: |
|
batch["attention_mask"] = inputs.get("attention_mask")[0] |
|
|
|
|
|
|
|
transcription = batch["sentence"] |
|
if do_lower_case: |
|
transcription = transcription.lower() |
|
if do_remove_punctuation: |
|
transcription = normalizer(transcription).strip() |
|
if do_bangla_unicode_normalization: |
|
_words = [bangla_normalizer(word)['normalized'] for word in transcription.split()] |
|
transcription = " ".join([word for word in _words if word is not None]) |
|
transcription = transcription.replace("\u2047", "-") |
|
transcription = transcription.replace(u"\u098c", u"\u09ef") |
|
transcription = unicodedata.normalize("NFC", transcription) |
|
transcription = removeOptionalZW(transcription) |
|
|
|
|
|
batch["labels"] = processor.tokenizer(transcription).input_ids |
|
|
|
|
|
batch["labels_length"] = len(batch["labels"]) |
|
|
|
return batch |
|
|
|
|
|
|
|
|
|
|
|
|
|
my_dataset = my_dataset.map(prepare_dataset, |
|
num_proc=1, |
|
load_from_cache_file=True, |
|
cache_file_names={ |
|
"train" : os.path.join(base_dir, 'datasets_cache', 'preprocessed_train_cache.arrow'), |
|
"test" : os.path.join(base_dir, 'datasets_cache', 'preprocessed_test_cache.arrow'), |
|
} |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n AFTER PREPROCESSING, final train and validation sets are: ") |
|
print(my_dataset) |
|
print("\n") |
|
|
|
|
|
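# Whisper's feature extractor pads/truncates audio to 30 s, so drop samples that are empty
# or longer than 30 s (30 * 16000 = 480000 samples).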
MAX_DURATION_IN_SECONDS = 30.0 |
|
max_input_length = MAX_DURATION_IN_SECONDS * 16000 |
|
|
|
def filter_inputs(input_length): |
|
"""Filter inputs with zero input length or longer than 30s""" |
|
return 0 < input_length < max_input_length |
|
|
|
my_dataset = my_dataset.filter(filter_inputs, input_columns=["input_length"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n AFTER FILTERING INPUTS, final train and validation sets are: ") |
|
print(my_dataset) |
|
print("\n") |
|
|
|
max_label_length = generation_max_length |
|
|
|
def filter_labels(labels_length): |
|
"""Filter label sequences longer than max length (448)""" |
|
return labels_length < max_label_length |
|
|
|
my_dataset = my_dataset.filter(filter_labels, input_columns=["labels_length"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n AFTER FILTERING LABELS, final train and validation sets are: ") |
|
print(my_dataset) |
|
print("\n") |
|
|
|
|
|
import re |
|
def filter_transcripts(transcript):
    """Keep transcripts that contain at least two words and no English letters or digits"""
    pattern = r'^.*[a-zA-Z0-9]+.*$'
    match = re.match(pattern, transcript)
    return len(transcript.split(" ")) > 1 and not bool(match)
|
|
|
my_dataset = my_dataset.filter(filter_transcripts, input_columns=["sentence"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n AFTER FILTERING TRANSCRIPTS, final train and validation sets are: ") |
|
print("\n My FINAL DATASET \n") |
|
print(my_dataset) |
|
print("\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n Removing UNUSED Cache Files: \n") |
|
try: |
|
print(f"{common_voice.cleanup_cache_files()} for common_voice") |
|
print(f"{google_fleurs.cleanup_cache_files()} for google_fleurs") |
|
print(f"{openslr.cleanup_cache_files()} for openslr") |
|
print(f"{crblp.cleanup_cache_files()} for crblp") |
|
print(f"{my_dataset.cleanup_cache_files()} for my_dataset") |
|
|
|
except Exception as e: |
|
print(f"\n\n UNABLE to REMOVE some Cache files. \n Error: {e} \n\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Union |
|
|
|
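# Data collator: pads log-Mel features and tokenised labels separately, replaces label
# padding with -100 so it is ignored by the loss, and strips a leading BOS token if the
# tokenizer already added one (it is added again later).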
@dataclass |
|
class DataCollatorSpeechSeq2SeqWithPadding: |
|
processor: Any |
|
forward_attention_mask: bool |
|
|
|
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: |
|
|
|
|
|
input_features = [{"input_features": feature["input_features"]} for feature in features] |
|
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") |
|
|
|
if self.forward_attention_mask: |
|
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features]) |
|
|
|
|
|
label_features = [{"input_ids": feature["labels"]} for feature in features] |
|
|
|
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") |
|
|
|
|
|
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) |
|
|
|
|
|
|
|
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): |
|
labels = labels[:, 1:] |
|
|
|
batch["labels"] = labels |
|
|
|
return batch |
|
|
|
|
|
|
|
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, forward_attention_mask=apply_spec_augment) |
|
|
|
|
|
|
|
|
|
import evaluate |
|
|
|
wer_metric = evaluate.load("wer", cache_dir=os.path.join(base_dir, "metrics_cache")) |
|
cer_metric = evaluate.load("cer", cache_dir=os.path.join(base_dir, "metrics_cache")) |
|
|
|
do_normalize_eval = True |
|
|
|
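# compute_metrics: restore padded label ids, decode predictions and references, optionally
# normalise both, and report word and character error rates as percentages.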
def compute_metrics(pred): |
|
pred_ids = pred.predictions |
|
label_ids = pred.label_ids |
|
|
|
|
|
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id |
|
|
|
|
|
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True) |
|
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True) |
|
|
|
if do_normalize_eval: |
|
pred_str = [normalizer(pred) for pred in pred_str] |
|
label_str = [normalizer(label) for label in label_str] |
|
|
|
wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str) |
|
cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str) |
|
|
|
return {"cer": cer, "wer": wer} |
|
|
|
|
|
|
|
print("\n\n Loading Model to Device..\n\n") |
|
|
|
from transformers import WhisperForConditionalGeneration |
|
|
|
model = WhisperForConditionalGeneration.from_pretrained(model_name) |
|
model = model.to(device) |
|
|
|
|
|
|
|
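# Regularisation and generation settings: enable SpecAugment and dropout, and clear the
# forced decoder ids / suppressed tokens so no tokens are forced or suppressed at training time.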
model.config.apply_spec_augment = apply_spec_augment |
|
model.config.max_length = generation_max_length |
|
model.config.dropout = dropout |
|
model.config.forced_decoder_ids = None |
|
model.config.suppress_tokens = [] |
|
if gradient_checkpointing: |
|
model.config.use_cache = False |
|
if freeze_feature_encoder: |
|
model.freeze_feature_encoder() |
|
|
|
model.generation_config.max_length = generation_max_length |
|
|
|
|
|
from transformers import Seq2SeqTrainingArguments |
|
|
|
training_args = Seq2SeqTrainingArguments( |
|
output_dir=output_dir, |
|
overwrite_output_dir=overwrite_output_dir, |
|
max_steps=max_steps, |
|
per_device_train_batch_size=per_device_train_batch_size, |
|
per_device_eval_batch_size=per_device_eval_batch_size, |
|
gradient_accumulation_steps=gradient_accumulation_steps, |
|
gradient_checkpointing=gradient_checkpointing, |
|
dataloader_num_workers=dataloader_num_workers, |
|
evaluation_strategy=evaluation_strategy, |
|
eval_steps=eval_steps, |
|
save_strategy=save_strategy, |
|
save_steps=save_steps, |
|
save_total_limit=save_total_limit, |
|
learning_rate=learning_rate, |
|
lr_scheduler_type=lr_scheduler_type, |
|
warmup_steps=warmup_steps, |
|
logging_steps=logging_steps, |
|
weight_decay=weight_decay, |
|
load_best_model_at_end=load_best_model_at_end, |
|
metric_for_best_model=metric_for_best_model, |
|
greater_is_better=greater_is_better, |
|
bf16=bf16, |
|
tf32=tf32, |
|
torch_compile=torch_compile, |
|
optim=optim, |
|
generation_max_length=generation_max_length, |
|
report_to=report_to, |
|
predict_with_generate=predict_with_generate, |
|
push_to_hub=push_to_hub, |
|
) |
|
|
|
from transformers import Seq2SeqTrainer, EarlyStoppingCallback
|
|
|
trainer = Seq2SeqTrainer( |
|
args=training_args, |
|
model=model, |
|
train_dataset=my_dataset["train"], |
|
eval_dataset=my_dataset["test"], |
|
data_collator=data_collator, |
|
compute_metrics=compute_metrics, |
|
tokenizer=processor.feature_extractor, |
|
    callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
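# Save the processor (feature extractor + tokenizer) so the fine-tuned checkpoints can be
# reloaded together with it for inference.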
processor.save_pretrained("best_model") |
|
|
|
|
|
|
|
print("\n\n Training STARTED..\n\n") |
|
|
|
train_result = trainer.train() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n Training COMPLETED...\n\n") |
|
|
|
|
|
|
|
print("\n\n Evaluating Model & Saving Metrics...\n\n") |
|
|
|
processor.save_pretrained(save_directory=output_dir) |
|
|
|
|
|
metrics = train_result.metrics |
|
trainer.save_metrics("train", metrics) |
|
trainer.save_state() |
|
|
|
metrics = trainer.evaluate( |
|
metric_key_prefix="eval", |
|
max_length=training_args.generation_max_length, |
|
num_beams=training_args.generation_num_beams, |
|
) |
|
|
|
trainer.save_metrics("eval", metrics) |
|
|
|
|
|
|
|
if push_to_hub: |
|
print("\n\n Pushing to Hub...\n\n") |
|
|
|
trainer.create_model_card() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trainer.push_to_hub() |
|
|
|
|
|
print("\n\n DONEEEEEE \n\n") |
|
|