Commit
·
e5deaf1
1
Parent(s):
afab1b0
training script
Browse files- telugu_xlmr.py +246 -0
telugu_xlmr.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import pandas as pd
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
|
10 |
+
from datasets import load_dataset, load_metric, Audio
|
11 |
+
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from typing import Dict, List, Union
|
14 |
+
from IPython.display import display, HTML
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class dataset_gen:
|
19 |
+
|
20 |
+
def __init__(self,processor):
|
21 |
+
self.processor = processor
|
22 |
+
|
23 |
+
def prepare_dataset(self,batch):
|
24 |
+
audio = batch["audio"]
|
25 |
+
|
26 |
+
batch["input_values"] = self.processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
|
27 |
+
batch["input_length"] = len(batch["input_values"])
|
28 |
+
|
29 |
+
with self.processor.as_target_processor():
|
30 |
+
batch["labels"] = self.processor(batch["sentence"]).input_ids
|
31 |
+
return batch
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class DataCollatorCTCWithPadding:
|
35 |
+
|
36 |
+
processor: Wav2Vec2Processor
|
37 |
+
padding: Union[bool, str] = True
|
38 |
+
|
39 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
40 |
+
# split inputs and labels since they have to be of different lenghts and need
|
41 |
+
# different padding methods
|
42 |
+
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
43 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
44 |
+
|
45 |
+
batch = self.processor.pad(
|
46 |
+
input_features,
|
47 |
+
padding=self.padding,
|
48 |
+
return_tensors="pt",
|
49 |
+
)
|
50 |
+
with self.processor.as_target_processor():
|
51 |
+
labels_batch = self.processor.pad(
|
52 |
+
label_features,
|
53 |
+
padding=self.padding,
|
54 |
+
return_tensors="pt",
|
55 |
+
)
|
56 |
+
|
57 |
+
# replace padding with -100 to ignore loss correctly
|
58 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
59 |
+
|
60 |
+
batch["labels"] = labels
|
61 |
+
|
62 |
+
return batch
|
63 |
+
|
64 |
+
class metrics:
|
65 |
+
|
66 |
+
def __init__(self,processor,wer_metric):
|
67 |
+
self.processor = processor
|
68 |
+
self.wer_metric = wer_metric
|
69 |
+
|
70 |
+
def compute_metrics(self,pred):
|
71 |
+
pred_logits = pred.predictions
|
72 |
+
pred_ids = np.argmax(pred_logits, axis=-1)
|
73 |
+
|
74 |
+
pred.label_ids[pred.label_ids == -100] = self.processor.tokenizer.pad_token_id
|
75 |
+
|
76 |
+
pred_str = self.processor.batch_decode(pred_ids)
|
77 |
+
# we do not want to group tokens when computing the metrics
|
78 |
+
label_str = self.processor.batch_decode(pred.label_ids, group_tokens=False)
|
79 |
+
|
80 |
+
wer = self.wer_metric.compute(predictions=pred_str, references=label_str)
|
81 |
+
|
82 |
+
return {"wer": wer}
|
83 |
+
|
84 |
+
def show_random_elements(dataset, num_examples=10):
|
85 |
+
assert num_examples <= len(dataset)
|
86 |
+
picks = []
|
87 |
+
for _ in range(num_examples):
|
88 |
+
pick = random.randint(0, len(dataset)-1)
|
89 |
+
while pick in picks:
|
90 |
+
pick = random.randint(0, len(dataset)-1)
|
91 |
+
picks.append(pick)
|
92 |
+
|
93 |
+
df = pd.DataFrame(dataset[picks])
|
94 |
+
display(HTML(df.to_html()))
|
95 |
+
|
96 |
+
|
97 |
+
def remove_special_characters(batch):
|
98 |
+
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\&\/\d\_\\\]'
|
99 |
+
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
|
100 |
+
batch["sentence"] = re.sub('\u200c', '', batch["sentence"])
|
101 |
+
batch["sentence"] = re.sub('[a-z]', '', batch["sentence"])
|
102 |
+
return batch
|
103 |
+
|
104 |
+
|
105 |
+
def extract_all_chars(batch):
|
106 |
+
all_text = " ".join(batch["sentence"])
|
107 |
+
vocab = list(set(all_text))
|
108 |
+
return {"vocab": [vocab], "all_text": [all_text]}
|
109 |
+
|
110 |
+
|
111 |
+
def preprocess_labels(telugu_train,telugu_test):
|
112 |
+
|
113 |
+
|
114 |
+
telugu_train = telugu_train.map(remove_special_characters)
|
115 |
+
telugu_test = telugu_test.map(remove_special_characters)
|
116 |
+
|
117 |
+
|
118 |
+
vocab_train = telugu_train.map(extract_all_chars, batched=True, batch_size=-1,
|
119 |
+
keep_in_memory=True, remove_columns=telugu_train.column_names)
|
120 |
+
vocab_test = telugu_test.map(extract_all_chars, batched=True, batch_size=-1,
|
121 |
+
keep_in_memory=True, remove_columns=telugu_test.column_names)
|
122 |
+
|
123 |
+
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
|
124 |
+
|
125 |
+
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
|
126 |
+
vocab_dict["|"] = vocab_dict[" "]
|
127 |
+
del vocab_dict[" "]
|
128 |
+
|
129 |
+
vocab_dict["[UNK]"] = len(vocab_dict)
|
130 |
+
vocab_dict["[PAD]"] = len(vocab_dict)
|
131 |
+
|
132 |
+
with open('vocab.json', 'w') as vocab_file:
|
133 |
+
json.dump(vocab_dict, vocab_file)
|
134 |
+
|
135 |
+
return telugu_train, telugu_test
|
136 |
+
|
137 |
+
|
138 |
+
def preprocess_audio(telugu_train,telugu_test,processor):
|
139 |
+
|
140 |
+
telugu_train = telugu_train.cast_column("audio", Audio(sampling_rate=16_000))
|
141 |
+
telugu_test = telugu_test.cast_column("audio", Audio(sampling_rate=16_000))
|
142 |
+
|
143 |
+
dataset = dataset_gen(processor)
|
144 |
+
|
145 |
+
telugu_train = telugu_train.map(dataset.prepare_dataset, remove_columns=telugu_train.column_names)
|
146 |
+
telugu_test = telugu_test.map(dataset.prepare_dataset, remove_columns=telugu_test.column_names)
|
147 |
+
|
148 |
+
return telugu_train, telugu_test
|
149 |
+
|
150 |
+
def main(args):
|
151 |
+
|
152 |
+
repo_name = args.repo_name
|
153 |
+
telugu_dataset = load_dataset(args.dataset, args.config)
|
154 |
+
train_testvalid = telugu_dataset['train'].train_test_split(test_size=args.test_split_size)
|
155 |
+
telugu_train = train_testvalid["train"]
|
156 |
+
telugu_test = train_testvalid["test"]
|
157 |
+
|
158 |
+
telugu_train,telugu_test = preprocess_labels(telugu_train,telugu_test)
|
159 |
+
|
160 |
+
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
|
161 |
+
|
162 |
+
tokenizer.push_to_hub(repo_name)
|
163 |
+
|
164 |
+
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
|
165 |
+
|
166 |
+
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
167 |
+
|
168 |
+
telugu_train,telugu_test = preprocess_audio(telugu_train,telugu_test,processor)
|
169 |
+
|
170 |
+
|
171 |
+
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
|
172 |
+
|
173 |
+
wer_metric = load_metric("wer")
|
174 |
+
|
175 |
+
metric = metrics(processor,wer_metric)
|
176 |
+
|
177 |
+
model = Wav2Vec2ForCTC.from_pretrained(
|
178 |
+
args.model_id,
|
179 |
+
attention_dropout=0.0,
|
180 |
+
hidden_dropout=0.0,
|
181 |
+
feat_proj_dropout=0.0,
|
182 |
+
mask_time_prob=0.05,
|
183 |
+
layerdrop=0.0,
|
184 |
+
ctc_loss_reduction="mean",
|
185 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
186 |
+
vocab_size=len(processor.tokenizer),
|
187 |
+
)
|
188 |
+
|
189 |
+
model.freeze_feature_extractor()
|
190 |
+
|
191 |
+
training_args = TrainingArguments(
|
192 |
+
output_dir=repo_name,
|
193 |
+
group_by_length=True,
|
194 |
+
per_device_train_batch_size=16,
|
195 |
+
gradient_accumulation_steps=2,
|
196 |
+
evaluation_strategy="steps",
|
197 |
+
num_train_epochs=args.epochs,
|
198 |
+
gradient_checkpointing=True,
|
199 |
+
fp16=True,
|
200 |
+
save_steps=400,
|
201 |
+
eval_steps=400,
|
202 |
+
logging_steps=400,
|
203 |
+
learning_rate=3e-4,
|
204 |
+
warmup_steps=500,
|
205 |
+
save_total_limit=2,
|
206 |
+
push_to_hub=True,
|
207 |
+
)
|
208 |
+
|
209 |
+
trainer = Trainer(
|
210 |
+
model=model,
|
211 |
+
data_collator=data_collator,
|
212 |
+
args=training_args,
|
213 |
+
compute_metrics=metric.compute_metrics,
|
214 |
+
train_dataset=telugu_train,
|
215 |
+
eval_dataset=telugu_test,
|
216 |
+
tokenizer=processor.feature_extractor,
|
217 |
+
)
|
218 |
+
|
219 |
+
output = trainer.train()
|
220 |
+
print(output)
|
221 |
+
trainer.push_to_hub()
|
222 |
+
|
223 |
+
if __name__=="__main__":
|
224 |
+
parser = argparse.ArgumentParser()
|
225 |
+
|
226 |
+
parser.add_argument(
|
227 |
+
"--model_id", type=str, required=True, default="facebook/wav2vec2-large-xlsr-53", help="Model identifier. Should be loadable with 🤗 Transformers"
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--dataset", type=str, required=True, default="openslr", help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
|
231 |
+
)
|
232 |
+
parser.add_argument(
|
233 |
+
"--config", type=str, required=True, default="SLR66", help="Config of the dataset. *E.g.* `'en'` for Common Voice"
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--num_epochs", type=int, required =False, help="Number of epochs for training"
|
237 |
+
)
|
238 |
+
parser.add_argument(
|
239 |
+
"--repo_name", type=str, help="Name of the repo for storing files"
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--test_split_size", type= int, default=0.25, required=False, help="split size for test set from dataset"
|
243 |
+
)
|
244 |
+
|
245 |
+
args = parser.parse_args()
|
246 |
+
main(args)
|