Upload 2 files

ctcalign.py  +43 -41
graph.py     +2 -5

ctcalign.py CHANGED
@@ -5,39 +5,43 @@ from dataclasses import dataclass
 
 
 
+
 #convert frame-numbers to timestamps in seconds
 # w2v2 step size is about 20ms, or 50 frames per second
 def f2s(fr):
-    return fr/50
-
-
-def get_processor_labels(processor,word_sep,max_labels=100):
-    ixs = sorted(list(range(max_labels)),reverse=True)
-    return {processor.tokenizer.decode(n) or word_sep:n for n in ixs}
-
+    return fr/50
+
+
 #------------------------------------------
 # setup wav2vec2
 #------------------------------------------
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-torch.random.manual_seed(0)
-max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
-
 # important to know for CTC decoding - potentially language/model dependent
-model_word_separator = '|'
-model_blank_token = '[PAD]'
+#model_word_separator = '|'
+#model_blank_token = '[PAD]'
+#is_MODEL_PATH="../models/LVL/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
+
 
-is_MODEL_PATH="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
 
+class CTCAligner:
+
+    def __init__(self, model_path,model_word_separator, model_blank_token):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        torch.random.manual_seed(0)
+
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
+        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
+
+        # build labels dict from a processor where it is not directly accessible
+        max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
+        ixs = sorted(list(range(max_labels)),reverse=True)
+        self.labels_dict = {self.processor.tokenizer.decode(n) or model_word_separator:n for n in ixs}
+
+        self.blank_id = self.labels_dict[model_blank_token]
+        self.model_word_separator = model_word_separator
 
 
 
-model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
-processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
-labels_dict = get_processor_labels(processor,model_word_separator)
-inverse_dict = {v:k for k,v in labels_dict.items()}
-all_labels = tuple(labels_dict.keys())
-blank_id = labels_dict[model_blank_token]
 
 
 #------------------------------------------
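Everything that used to live at module level now travels with the class. A minimal usage sketch, assuming ctcalign.py is importable as a module; the checkpoint path and the '|' / '[PAD]' tokens below simply mirror the globals removed above:

# Sketch only: one CTCAligner per language/model.
from ctcalign import CTCAligner

is_aligner = CTCAligner(
    model_path="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
    model_word_separator='|',
    model_blank_token='[PAD]',
)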
@@ -49,11 +53,11 @@ blank_id = labels_dict[model_blank_token]
 
 # return the label class probability of each audio frame
 # wav is the wav data already read in, NOT the file path.
-def get_frame_probs(wav):
+def get_frame_probs(wav,aligner):
     with torch.inference_mode(): # similar to with torch.no_grad():
-        input_values = processor(wav,sampling_rate=16000).input_values[0]
-        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
-        emits = model(input_values).logits
+        input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
+        input_values = torch.tensor(input_values, device=aligner.device).unsqueeze(0)
+        emits = aligner.model(input_values).logits
     emits = torch.log_softmax(emits, dim=-1)
     return emits[0].cpu().detach()
 
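get_frame_probs() now takes the aligner explicitly instead of closing over globals. A short call sketch; soundfile is an assumed reader here, and the clip must already be 16 kHz mono since the processor is invoked with sampling_rate=16000:

# Sketch only: per-frame CTC log-probabilities for one clip.
# 'example.wav' is a placeholder; is_aligner as constructed in the sketch above.
import soundfile as sf

wav, sr = sf.read("example.wav")          # 16 kHz mono float samples
emits = get_frame_probs(wav, is_aligner)  # tensor of shape (n_frames, n_labels)
# per f2s() above, each frame covers roughly 20 ms of audio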
@@ -195,28 +199,26 @@ def basic(segs,wsep="|"):
 
 # needs pad labels added to correctly time first segment
 # and therefore add word sep character as placeholder in transcript
-def prep_transcript(xcp):
-    xcp = xcp.replace(' ',model_word_separator)
-    label_ids = [labels_dict[c] for c in xcp]
-    label_ids = [blank_id] + label_ids + [blank_id]
-    xcp = f'{model_word_separator}{xcp}{model_word_separator}'
+def prep_transcript(xcp, aligner):
+    xcp = xcp.replace(' ', aligner.model_word_separator)
+    label_ids = [aligner.labels_dict[c] for c in xcp]
+    label_ids = [aligner.blank_id] + label_ids + [aligner.blank_id]
+    xcp = f'{ aligner.model_word_separator}{xcp}{aligner.model_word_separator}'
     return xcp,label_ids
-
-
-def align(wav_data,transcript):
-    norm_transcript,rec_label_ids = prep_transcript(transcript)
-    emit = get_frame_probs(wav_data)
-    trellis = get_trellis(emit, rec_label_ids, blank_id)
-    path = backtrack(trellis, emit, rec_label_ids, blank_id)
+
+
+
+def align(wav_data,transcript,aligner):
+    norm_transcript,rec_label_ids = prep_transcript(transcript, aligner)
+    emit = get_frame_probs(wav_data,aligner)
+    trellis = get_trellis(emit, rec_label_ids, aligner.blank_id)
+    path = backtrack(trellis, emit, rec_label_ids, aligner.blank_id)
 
     segments = merge_repeats(path,norm_transcript)
-    words = merge_words(segments, model_word_separator)
+    words = merge_words(segments, aligner.model_word_separator)
 
     #segments = [s for s in segments if s[0] != model_word_separator]
     #return mfalike(segments,words,model_word_separator)
-    return basic(words,model_word_separator), basic(segments,model_word_separator)
-
-
-
+    return basic(words,aligner.model_word_separator), basic(segments,aligner.model_word_separator)
 
 
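With the refactor, align() is the single entry point for the whole pipeline. A hypothetical end-to-end sketch; the transcript is a placeholder and must contain only spaces plus characters present in the model's vocabulary:

# Sketch only: word- and character-level alignments for one recording.
# Reuses is_aligner and the soundfile loading from the sketches above.
wav, _ = sf.read("example.wav")
words, segments = align(wav, "halló heimur", is_aligner)
# both values come back in whatever format basic() produces: word spans
# and character spans of the padded transcript over the audio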
graph.py CHANGED
@@ -41,16 +41,15 @@ def get_pitch_tracks(wav_path):
 
 
 
-
 # transcript could be from a corpus with the wav file,
 # input by the user,
 # or from a previous speech recognition process
-def align_and_graph(wav_path, transcript):
+def align_and_graph(wav_path, transcript,lang_aligner):
 
     # fetch data
     #f0_data = get_pitch_tracks(wav_path)
     speech = readwav(wav_path)
-    w_align, seg_align = ctcalign.align(speech,normalise_transcript(transcript))
+    w_align, seg_align = ctcalign.align(speech,normalise_transcript(transcript),lang_aligner)
 
 
     # set up the graph shape

@@ -113,5 +112,3 @@ def align_and_graph(wav_path, transcript):
 
 
 # uppboðssøla bussleiðini viðmerkingar upprunaligur
-
-
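With the aligner injected, the caller now decides which language's model is used. A hypothetical call-site sketch (module layout assumed from this repo; the wav path and transcript are placeholders):

# Sketch only: build one aligner per language up front, reuse it per recording.
import ctcalign
from graph import align_and_graph

is_aligner = ctcalign.CTCAligner(
    "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
    '|', '[PAD]')
align_and_graph("clip.wav", "halló heimur", is_aligner)
# normalise_transcript() is applied inside, so a raw transcript string is fine;
# the function's return value is not shown in this diff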