clr committed
Commit 459923a · Parent: 7116577

Upload 2 files

Files changed (2)
  1. ctcalign.py +222 -0
  2. graph.py +117 -0
ctcalign.py ADDED
@@ -0,0 +1,222 @@
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ import torch
+ import numpy as np
+ from dataclasses import dataclass
+
+
+
+ # convert frame numbers to timestamps in seconds
+ # the wav2vec2 step size is about 20 ms, i.e. 50 frames per second
+ def f2s(fr):
+     return fr/50
+
+ # build a labels dict from a processor where it is not directly accessible
+ def get_processor_labels(processor, word_sep, max_labels=100):
+     ixs = sorted(list(range(max_labels)), reverse=True)
+     return {processor.tokenizer.decode(n) or word_sep: n for n in ixs}
+
+ #------------------------------------------
+ # set up wav2vec2
+ #------------------------------------------
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.random.manual_seed(0)
+ max_labels = 100  # any reasonable number higher than vocab + extra + special tokens in any language used
+
+ # important to know for CTC decoding - potentially language/model dependent
+ model_word_separator = '|'
+ model_blank_token = '[PAD]'
+
+ is_MODEL_PATH = "../models/LVL/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
+
+
+
+
+ model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
+ processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
+ labels_dict = get_processor_labels(processor, model_word_separator)
+ inverse_dict = {v: k for k, v in labels_dict.items()}
+ all_labels = tuple(labels_dict.keys())
+ blank_id = labels_dict[model_blank_token]
+
+
+ #------------------------------------------
+ # forced alignment with a CTC decoder,
+ # based on the implementation in
+ # https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
+ #------------------------------------------
+
+
+ # return the label class log-probability of each audio frame
+ # wav is the wav data already read in, NOT the file path
+ def get_frame_probs(wav):
+     with torch.inference_mode():  # similar to torch.no_grad()
+         input_values = processor(wav, sampling_rate=16000).input_values[0]
+         input_values = torch.tensor(input_values, device=device).unsqueeze(0)
+         emits = model(input_values).logits
+         emits = torch.log_softmax(emits, dim=-1)
+     return emits[0].cpu().detach()
+
+
+ def get_trellis(emission, tokens, blank_id):
+
+     num_frame = emission.size(0)
+     num_tokens = len(tokens)
+     trellis = torch.empty((num_frame + 1, num_tokens + 1))
+     trellis[0, 0] = 0
+     trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)  # the length of this slice of the trellis is the number of audio frames
+     trellis[0, -num_tokens:] = -float("inf")  # the length of this slice of the trellis is the number of transcript tokens
+     trellis[-num_tokens:, 0] = float("inf")
+     for t in range(num_frame):
+         trellis[t + 1, 1:] = torch.maximum(
+             # Score for staying at the same token
+             trellis[t, 1:] + emission[t, blank_id],
+             # Score for changing to the next token
+             trellis[t, :-1] + emission[t, tokens],
+         )
+     return trellis
+
+
+
+ @dataclass
+ class Point:
+     token_index: int
+     time_index: int
+     score: float
+
+ @dataclass
+ class Segment:
+     label: str
+     start: int
+     end: int
+     score: float
+
+     @property
+     def mfaform(self):
+         return f"{f2s(self.start)},{f2s(self.end)},{self.label}"
+
+     @property
+     def length(self):
+         return self.end - self.start
+
+
+
+ def backtrack(trellis, emission, tokens, blank_id):
+     # Note:
+     # j and t are indices for the trellis, which has extra dimensions
+     # for time and tokens at the beginning.
+     # When referring to time frame index `T` in the trellis,
+     # the corresponding index in emission is `T-1`.
+     # Similarly, when referring to token index `J` in the trellis,
+     # the corresponding index in the transcript is `J-1`.
+     j = trellis.size(1) - 1
+     t_start = torch.argmax(trellis[:, j]).item()
+
+     path = []
+     for t in range(t_start, 0, -1):
+         # 1. Figure out whether the current position was a stay or a change.
+         # `emission[T-1]` is the emission at time frame `T` of the trellis dimension.
+         # Score for the token staying the same from time frame T-1 to T.
+         stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
+         # Score for the token changing from J-1 at T-1 to J at T.
+         changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+         # 2. Store the path with frame-wise probability.
+         prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
+         # Return token index and time index in non-trellis coordinates.
+         path.append(Point(j - 1, t - 1, prob))
+
+         # 3. Update the token.
+         if changed > stayed:
+             j -= 1
+             if j == 0:
+                 break
+     else:
+         raise ValueError("Failed to align")
+     return path[::-1]
+
+
+ def merge_repeats(path, transcript):
+     i1, i2 = 0, 0
+     segments = []
+     while i1 < len(path):
+         while i2 < len(path) and path[i1].token_index == path[i2].token_index:  # while both path steps point to the same token index
+             i2 += 1
+         score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
+         segments.append(  # when i2 finally switches to a different token,
+             Segment(
+                 transcript[path[i1].token_index],  # append that token to the list of segments,
+                 path[i1].time_index,  # with the time of its first path point
+                 path[i2 - 1].time_index + 1,  # and the time of its final path point.
+                 score,
+             )
+         )
+         i1 = i2
+     return segments
+
+
+
+ def merge_words(segments, separator):
+     words = []
+     i1, i2 = 0, 0
+     while i1 < len(segments):
+         if i2 >= len(segments) or segments[i2].label == separator:
+             if i1 != i2:
+                 segs = segments[i1:i2]
+                 word = "".join([seg.label for seg in segs])
+                 score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
+                 words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
+             i1 = i2 + 1
+             i2 = i1
+         else:
+             i2 += 1
+     return words
+
+
+
+ #------------------------------------------
+ # output handling etc.
+ #------------------------------------------
+
+
+ # generate MFA-like csv format for character (phone) and word alignments
+ # skip the word separator as it is not a phone
+ def mfalike(chars, wds, wsep):
+     hed = ['Begin,End,Label,Type,Speaker\n']
+     wlines = [f'{w.mfaform},words,000\n' for w in wds]
+     slines = [f'{ch.mfaform},phones,000\n' for ch in chars if ch.label != wsep]
+     return ''.join(hed + wlines + slines)
+
+ # generate a basic exportable list format for character OR word alignments
+ # skip the word separator as it is not a phone
+ def basic(segs, wsep="|"):
+     return [[s.label, f2s(s.start), f2s(s.end)] for s in segs if s.label != wsep]
+
+
+ # pad labels must be added to correctly time the first segment,
+ # therefore add the word separator character as a placeholder in the transcript
+ def prep_transcript(xcp):
+     xcp = xcp.replace(' ', model_word_separator)
+     label_ids = [labels_dict[c] for c in xcp]
+     label_ids = [blank_id] + label_ids + [blank_id]
+     xcp = f'{model_word_separator}{xcp}{model_word_separator}'
+     return xcp, label_ids
+
+
+ def align(wav_data, transcript):
+     norm_transcript, rec_label_ids = prep_transcript(transcript)
+     emit = get_frame_probs(wav_data)
+     trellis = get_trellis(emit, rec_label_ids, blank_id)
+     path = backtrack(trellis, emit, rec_label_ids, blank_id)
+
+     segments = merge_repeats(path, norm_transcript)
+     words = merge_words(segments, model_word_separator)
+
+     #segments = [s for s in segments if s[0] != model_word_separator]
+     #return mfalike(segments,words,model_word_separator)
+     return basic(words, model_word_separator), basic(segments, model_word_separator)
+
+
+
+
+
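A minimal usage sketch (not part of the commit): ctcalign.align expects 16 kHz mono audio samples rather than a file path, plus a transcript whose characters appear in the model's vocabulary, and returns word and character alignments as [label, start_seconds, end_seconds] lists. The wav path and transcript below are hypothetical.

import soundfile as sf
import ctcalign

wav, sr = sf.read("example.wav", dtype="float32")  # hypothetical 16 kHz mono recording
assert sr == 16000  # ctcalign assumes 16 kHz input

words, chars = ctcalign.align(wav, "halló heimur")  # hypothetical matching transcript
for label, start, end in words:
    print(f"{label}\t{start:.2f}\t{end:.2f}")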
graph.py ADDED
@@ -0,0 +1,117 @@
+ import numpy as np
+ import soundfile as sf
+ from scipy import signal
+ import librosa
+ import subprocess
+ import matplotlib.pyplot as plt
+ import ctcalign
+
+
+
+ def readwav(wav_path):
+     wav, sr = sf.read(wav_path, dtype=np.float32)
+     if len(wav.shape) == 2:
+         wav = wav.mean(1)
+     if sr != 16000:
+         wlen = int(wav.shape[0] / sr * 16000)
+         wav = signal.resample(wav, wlen)
+     return wav
+
+
+ def normalise_transcript(xcp):
+     xcp = xcp.lower()
+     while '  ' in xcp:
+         xcp = xcp.replace('  ', ' ')
+     return xcp
+
+
+
+ def get_pitch_tracks(wav_path):
+     # REAPER prints an EST-format pitch track; capture its stdout as text
+     f0_data = subprocess.run(["REAPER/build/reaper", "-i", wav_path, '-a'],
+                              capture_output=True, text=True).stdout
+     #with open('tmp.f0','r') as handle:
+     f0_data = f0_data.split('EST_Header_End\n')[1].splitlines()
+     f0_data = [l.split(' ') for l in f0_data]
+     # keep only voiced frames (voicing flag == '1') as [time, f0] pairs
+     f0_data = [[float(t), float(f)] for t, v, f in f0_data if v == '1']
+     return f0_data
+
+
+
+
+
+
+
+ # transcript could be from a corpus with the wav file,
+ # input by the user,
+ # or from a previous speech recognition process
+ def align_and_graph(wav_path, transcript):
+
+     # fetch data
+     #f0_data = get_pitch_tracks(wav_path)
+     speech = readwav(wav_path)
+     w_align, seg_align = ctcalign.align(speech, normalise_transcript(transcript))
+
+
+     # set up the graph shape
+     rec_start = w_align[0][1]
+     rec_end = w_align[-1][2]
+
+     f0_data = get_pitch_tracks(wav_path)
+     if f0_data:
+         f_max = max([f0 for t, f0 in f0_data]) + 50
+     else:
+         f_max = 400
+
+
+     fig, axes1 = plt.subplots(figsize=(15, 5))
+     plt.xlim([rec_start, rec_end])
+     axes1.set_ylim([0.0, f_max])
+     axes1.get_xaxis().set_visible(False)
+
+     # draw word boundaries
+     for w, s, e in w_align:
+         plt.vlines(s, 0, f_max, linewidth=0.5, color='black')
+         plt.vlines(e, 0, f_max, linewidth=0.5, color='black')
+         plt.text((s + e) / 2 - (len(w) * .01), f_max + 15, w, fontsize=15)
+
+     # draw phone/char boundaries
+     for p, s, e in seg_align:
+         plt.vlines(s, 0, f_max, linewidth=0.3, color='cadetblue', linestyle=(0, (10, 4)))
+         plt.vlines(e, 0, f_max, linewidth=0.3, color='cadetblue', linestyle=(0, (10, 4)))
+         plt.text((s + e) / 2 - (len(p) * .01), -30, p, fontsize=15, color='teal')
+
+
+     # plot the pitch track
+     f0c = "blue"
+     axes1.scatter([t for t, f0 in f0_data], [f0 for t, f0 in f0_data], color=f0c)
+
+
+
+     # rms energy, computed with librosa defaults
+     w, sr = librosa.load(wav_path)
+     fr_l = 2048  # librosa default frame length
+     h_l = 512    # librosa default hop length
+     rmse = librosa.feature.rms(y=w, frame_length=fr_l, hop_length=h_l)
+     rmse = rmse[0]
+
+
+     # show rms energy on a second y-axis
+     axes2 = axes1.twinx()
+     axes2.set_ylim([0.0, 0.5])
+     rms_xval = [(h_l * i) / sr for i in range(len(rmse))]
+     axes2.plot(rms_xval, rmse, color='peachpuff', linewidth=3.5)
+
+
+     # label the graph
+     axes1.set_ylabel("Pitch (F0, Hz)", fontsize=14, color="blue")
+     axes2.set_ylabel("RMS energy", fontsize=14, color="coral")
+     #plt.title(f'Recording {file_id} (L1 {language_dict[file_id]})', fontsize=15)
+     #plt.show()
+
+     return fig
+
+
+ #plt.close('all')
+
+
+ # uppboðssøla bussleiðini viðmerkingar upprunaligur
+
+
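A usage sketch for the plotting entry point, under the module's own assumptions (the hard-coded wav2vec2 model path in ctcalign.py and a REAPER binary at REAPER/build/reaper); the file names and transcript below are hypothetical.

import matplotlib
matplotlib.use("Agg")  # optional: render without a display
import graph

# hypothetical recording and matching transcript
fig = graph.align_and_graph("recordings/sample.wav", "halló heimur")
fig.savefig("sample_alignment.png", bbox_inches="tight")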