krishnateja commited on
Commit
e5deaf1
·
1 Parent(s): afab1b0

training script

Browse files
Files changed (1) hide show
  1. telugu_xlmr.py +246 -0
telugu_xlmr.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import pandas as pd
4
+ import re
5
+ import json
6
+ import torch
7
+ import argparse
8
+
9
+
10
+ from datasets import load_dataset, load_metric, Audio
11
+ from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
12
+ from dataclasses import dataclass, field
13
+ from typing import Dict, List, Union
14
+ from IPython.display import display, HTML
15
+
16
+
17
+
18
+ class dataset_gen:
19
+
20
+ def __init__(self,processor):
21
+ self.processor = processor
22
+
23
+ def prepare_dataset(self,batch):
24
+ audio = batch["audio"]
25
+
26
+ batch["input_values"] = self.processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
27
+ batch["input_length"] = len(batch["input_values"])
28
+
29
+ with self.processor.as_target_processor():
30
+ batch["labels"] = self.processor(batch["sentence"]).input_ids
31
+ return batch
32
+
33
+ @dataclass
34
+ class DataCollatorCTCWithPadding:
35
+
36
+ processor: Wav2Vec2Processor
37
+ padding: Union[bool, str] = True
38
+
39
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
40
+ # split inputs and labels since they have to be of different lenghts and need
41
+ # different padding methods
42
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
43
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
44
+
45
+ batch = self.processor.pad(
46
+ input_features,
47
+ padding=self.padding,
48
+ return_tensors="pt",
49
+ )
50
+ with self.processor.as_target_processor():
51
+ labels_batch = self.processor.pad(
52
+ label_features,
53
+ padding=self.padding,
54
+ return_tensors="pt",
55
+ )
56
+
57
+ # replace padding with -100 to ignore loss correctly
58
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
59
+
60
+ batch["labels"] = labels
61
+
62
+ return batch
63
+
64
+ class metrics:
65
+
66
+ def __init__(self,processor,wer_metric):
67
+ self.processor = processor
68
+ self.wer_metric = wer_metric
69
+
70
+ def compute_metrics(self,pred):
71
+ pred_logits = pred.predictions
72
+ pred_ids = np.argmax(pred_logits, axis=-1)
73
+
74
+ pred.label_ids[pred.label_ids == -100] = self.processor.tokenizer.pad_token_id
75
+
76
+ pred_str = self.processor.batch_decode(pred_ids)
77
+ # we do not want to group tokens when computing the metrics
78
+ label_str = self.processor.batch_decode(pred.label_ids, group_tokens=False)
79
+
80
+ wer = self.wer_metric.compute(predictions=pred_str, references=label_str)
81
+
82
+ return {"wer": wer}
83
+
84
+ def show_random_elements(dataset, num_examples=10):
85
+ assert num_examples <= len(dataset)
86
+ picks = []
87
+ for _ in range(num_examples):
88
+ pick = random.randint(0, len(dataset)-1)
89
+ while pick in picks:
90
+ pick = random.randint(0, len(dataset)-1)
91
+ picks.append(pick)
92
+
93
+ df = pd.DataFrame(dataset[picks])
94
+ display(HTML(df.to_html()))
95
+
96
+
97
+ def remove_special_characters(batch):
98
+ chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\&\/\d\_\\\]'
99
+ batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
100
+ batch["sentence"] = re.sub('\u200c', '', batch["sentence"])
101
+ batch["sentence"] = re.sub('[a-z]', '', batch["sentence"])
102
+ return batch
103
+
104
+
105
+ def extract_all_chars(batch):
106
+ all_text = " ".join(batch["sentence"])
107
+ vocab = list(set(all_text))
108
+ return {"vocab": [vocab], "all_text": [all_text]}
109
+
110
+
111
+ def preprocess_labels(telugu_train,telugu_test):
112
+
113
+
114
+ telugu_train = telugu_train.map(remove_special_characters)
115
+ telugu_test = telugu_test.map(remove_special_characters)
116
+
117
+
118
+ vocab_train = telugu_train.map(extract_all_chars, batched=True, batch_size=-1,
119
+ keep_in_memory=True, remove_columns=telugu_train.column_names)
120
+ vocab_test = telugu_test.map(extract_all_chars, batched=True, batch_size=-1,
121
+ keep_in_memory=True, remove_columns=telugu_test.column_names)
122
+
123
+ vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
124
+
125
+ vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
126
+ vocab_dict["|"] = vocab_dict[" "]
127
+ del vocab_dict[" "]
128
+
129
+ vocab_dict["[UNK]"] = len(vocab_dict)
130
+ vocab_dict["[PAD]"] = len(vocab_dict)
131
+
132
+ with open('vocab.json', 'w') as vocab_file:
133
+ json.dump(vocab_dict, vocab_file)
134
+
135
+ return telugu_train, telugu_test
136
+
137
+
138
+ def preprocess_audio(telugu_train,telugu_test,processor):
139
+
140
+ telugu_train = telugu_train.cast_column("audio", Audio(sampling_rate=16_000))
141
+ telugu_test = telugu_test.cast_column("audio", Audio(sampling_rate=16_000))
142
+
143
+ dataset = dataset_gen(processor)
144
+
145
+ telugu_train = telugu_train.map(dataset.prepare_dataset, remove_columns=telugu_train.column_names)
146
+ telugu_test = telugu_test.map(dataset.prepare_dataset, remove_columns=telugu_test.column_names)
147
+
148
+ return telugu_train, telugu_test
149
+
150
+ def main(args):
151
+
152
+ repo_name = args.repo_name
153
+ telugu_dataset = load_dataset(args.dataset, args.config)
154
+ train_testvalid = telugu_dataset['train'].train_test_split(test_size=args.test_split_size)
155
+ telugu_train = train_testvalid["train"]
156
+ telugu_test = train_testvalid["test"]
157
+
158
+ telugu_train,telugu_test = preprocess_labels(telugu_train,telugu_test)
159
+
160
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
161
+
162
+ tokenizer.push_to_hub(repo_name)
163
+
164
+ feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
165
+
166
+ processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
167
+
168
+ telugu_train,telugu_test = preprocess_audio(telugu_train,telugu_test,processor)
169
+
170
+
171
+ data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
172
+
173
+ wer_metric = load_metric("wer")
174
+
175
+ metric = metrics(processor,wer_metric)
176
+
177
+ model = Wav2Vec2ForCTC.from_pretrained(
178
+ args.model_id,
179
+ attention_dropout=0.0,
180
+ hidden_dropout=0.0,
181
+ feat_proj_dropout=0.0,
182
+ mask_time_prob=0.05,
183
+ layerdrop=0.0,
184
+ ctc_loss_reduction="mean",
185
+ pad_token_id=processor.tokenizer.pad_token_id,
186
+ vocab_size=len(processor.tokenizer),
187
+ )
188
+
189
+ model.freeze_feature_extractor()
190
+
191
+ training_args = TrainingArguments(
192
+ output_dir=repo_name,
193
+ group_by_length=True,
194
+ per_device_train_batch_size=16,
195
+ gradient_accumulation_steps=2,
196
+ evaluation_strategy="steps",
197
+ num_train_epochs=args.epochs,
198
+ gradient_checkpointing=True,
199
+ fp16=True,
200
+ save_steps=400,
201
+ eval_steps=400,
202
+ logging_steps=400,
203
+ learning_rate=3e-4,
204
+ warmup_steps=500,
205
+ save_total_limit=2,
206
+ push_to_hub=True,
207
+ )
208
+
209
+ trainer = Trainer(
210
+ model=model,
211
+ data_collator=data_collator,
212
+ args=training_args,
213
+ compute_metrics=metric.compute_metrics,
214
+ train_dataset=telugu_train,
215
+ eval_dataset=telugu_test,
216
+ tokenizer=processor.feature_extractor,
217
+ )
218
+
219
+ output = trainer.train()
220
+ print(output)
221
+ trainer.push_to_hub()
222
+
223
+ if __name__=="__main__":
224
+ parser = argparse.ArgumentParser()
225
+
226
+ parser.add_argument(
227
+ "--model_id", type=str, required=True, default="facebook/wav2vec2-large-xlsr-53", help="Model identifier. Should be loadable with 🤗 Transformers"
228
+ )
229
+ parser.add_argument(
230
+ "--dataset", type=str, required=True, default="openslr", help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
231
+ )
232
+ parser.add_argument(
233
+ "--config", type=str, required=True, default="SLR66", help="Config of the dataset. *E.g.* `'en'` for Common Voice"
234
+ )
235
+ parser.add_argument(
236
+ "--num_epochs", type=int, required =False, help="Number of epochs for training"
237
+ )
238
+ parser.add_argument(
239
+ "--repo_name", type=str, help="Name of the repo for storing files"
240
+ )
241
+ parser.add_argument(
242
+ "--test_split_size", type= int, default=0.25, required=False, help="split size for test set from dataset"
243
+ )
244
+
245
+ args = parser.parse_args()
246
+ main(args)