clr committed
Commit 7e57864 · 1 Parent(s): 331a033

Upload 2 files

Files changed (2):
  1. ctcalign.py +43 -41
  2. graph.py +2 -5
ctcalign.py CHANGED
@@ -5,39 +5,43 @@ from dataclasses import dataclass
 
 
 
+
 #convert frame-numbers to timestamps in seconds
 # w2v2 step size is about 20ms, or 50 frames per second
 def f2s(fr):
-    return fr/50
-
-# build labels dict from a processor where it is not directly accessible
-def get_processor_labels(processor,word_sep,max_labels=100):
-    ixs = sorted(list(range(max_labels)),reverse=True)
-    return {processor.tokenizer.decode(n) or word_sep:n for n in ixs}
-
+    return fr/50
+
+
 #------------------------------------------
 # setup wav2vec2
 #------------------------------------------
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-torch.random.manual_seed(0)
-max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
-
 # important to know for CTC decoding - potentially language/model dependent
-model_word_separator = '|'
-model_blank_token = '[PAD]'
+#model_word_separator = '|'
+#model_blank_token = '[PAD]'
+#is_MODEL_PATH="../models/LVL/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
+
 
-is_MODEL_PATH="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
 
+class CTCAligner:
+
+    def __init__(self, model_path, model_word_separator, model_blank_token):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        torch.random.manual_seed(0)
+
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
+        self.processor = Wav2Vec2Processor.from_pretrained(model_path)
+
+        # build labels dict from a processor where it is not directly accessible
+        max_labels = 100 # any reasonable number higher than vocab + extra + special tokens in any language used
+        ixs = sorted(list(range(max_labels)),reverse=True)
+        self.labels_dict = {self.processor.tokenizer.decode(n) or model_word_separator:n for n in ixs}
+
+        self.blank_id = self.labels_dict[model_blank_token]
+        self.model_word_separator = model_word_separator
 
 
 
-model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
-processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
-labels_dict = get_processor_labels(processor,model_word_separator)
-inverse_dict = {v:k for k,v in labels_dict.items()}
-all_labels = tuple(labels_dict.keys())
-blank_id = labels_dict[model_blank_token]
 
 
 #------------------------------------------
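
With this change the model, processor and label mapping live on a CTCAligner instance instead of module-level globals, so callers construct the aligner themselves. A minimal construction sketch, using the checkpoint id from the removed is_MODEL_PATH global and the separator/blank tokens now kept only as commented-out defaults:

    from ctcalign import CTCAligner

    # checkpoint and token values taken from the removed globals; any compatible
    # wav2vec2 CTC checkpoint with its own separator/blank tokens works the same way
    aligner = CTCAligner(
        model_path="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
        model_word_separator='|',
        model_blank_token='[PAD]',
    )
    # aligner.labels_dict maps each character label to its id,
    # and aligner.blank_id is the id of '[PAD]' used during CTC alignment.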
@@ -49,11 +53,11 @@ blank_id = labels_dict[model_blank_token]
 
 # return the label class probability of each audio frame
 # wav is the wav data already read in, NOT the file path.
-def get_frame_probs(wav):
+def get_frame_probs(wav,aligner):
     with torch.inference_mode(): # similar to with torch.no_grad():
-        input_values = processor(wav,sampling_rate=16000).input_values[0]
-        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
-        emits = model(input_values).logits
+        input_values = aligner.processor(wav,sampling_rate=16000).input_values[0]
+        input_values = torch.tensor(input_values, device=aligner.device).unsqueeze(0)
+        emits = aligner.model(input_values).logits
         emits = torch.log_softmax(emits, dim=-1)
     return emits[0].cpu().detach()
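
get_frame_probs now reads the model, processor and device off the aligner argument; it still returns one row of log-probabilities over the label set per roughly 20 ms frame. A small usage sketch, assuming 16 kHz mono audio loaded with soundfile and the aligner from the sketch above (the file name is illustrative):

    import soundfile as sf

    wav, sr = sf.read("example.wav")        # illustrative path, expected to be 16 kHz mono
    emits = get_frame_probs(wav, aligner)   # shape (n_frames, n_labels), log-softmaxed scores
    print(f2s(emits.shape[0]))              # frames / 50 ≈ audio duration in seconds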
 
@@ -195,28 +199,26 @@ def basic(segs,wsep="|"):
 
 # needs pad labels added to correctly time first segment
 # and therefore add word sep character as placeholder in transcript
-def prep_transcript(xcp):
-    xcp = xcp.replace(' ',model_word_separator)
-    label_ids = [labels_dict[c] for c in xcp]
-    label_ids = [blank_id] + label_ids + [blank_id]
-    xcp = f'{model_word_separator}{xcp}{model_word_separator}'
+def prep_transcript(xcp, aligner):
+    xcp = xcp.replace(' ', aligner.model_word_separator)
+    label_ids = [aligner.labels_dict[c] for c in xcp]
+    label_ids = [aligner.blank_id] + label_ids + [aligner.blank_id]
+    xcp = f'{aligner.model_word_separator}{xcp}{aligner.model_word_separator}'
     return xcp,label_ids
-
-
-def align(wav_data,transcript):
-    norm_transcript,rec_label_ids = prep_transcript(transcript)
-    emit = get_frame_probs(wav_data)
-    trellis = get_trellis(emit, rec_label_ids, blank_id)
-    path = backtrack(trellis, emit, rec_label_ids, blank_id)
+
+
+
+def align(wav_data,transcript,aligner):
+    norm_transcript,rec_label_ids = prep_transcript(transcript, aligner)
+    emit = get_frame_probs(wav_data,aligner)
+    trellis = get_trellis(emit, rec_label_ids, aligner.blank_id)
+    path = backtrack(trellis, emit, rec_label_ids, aligner.blank_id)
 
     segments = merge_repeats(path,norm_transcript)
-    words = merge_words(segments, model_word_separator)
+    words = merge_words(segments, aligner.model_word_separator)
 
     #segments = [s for s in segments if s[0] != model_word_separator]
     #return mfalike(segments,words,model_word_separator)
-    return basic(words,model_word_separator), basic(segments,model_word_separator)
-
-
-
+    return basic(words,aligner.model_word_separator), basic(segments,aligner.model_word_separator)
 
 
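With the aligner threaded through prep_transcript and align, alignment is now a call on the audio, the transcript and the aligner. A sketch of the end-to-end call, reusing the aligner and wav from the sketches above; the transcript is illustrative and must only use characters present in aligner.labels_dict, since prep_transcript replaces spaces with the word separator and pads both ends:

    # "halló heimur" is prepared internally as "|halló|heimur|",
    # with blank ids added at the start and end of the label sequence
    word_times, char_times = align(wav, "halló heimur", aligner)
    # both return values come from basic(); the trellis/backtrack/merge helpers
    # used inside align() are unchanged by this commit.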
 
 
graph.py CHANGED
@@ -41,16 +41,15 @@ def get_pitch_tracks(wav_path):
 
 
 
-
 # transcript could be from a corpus with the wav file,
 # input by the user,
 # or from a previous speech recognition process
-def align_and_graph(wav_path, transcript):
+def align_and_graph(wav_path, transcript,lang_aligner):
 
     # fetch data
     #f0_data = get_pitch_tracks(wav_path)
     speech = readwav(wav_path)
-    w_align, seg_align = ctcalign.align(speech,normalise_transcript(transcript))
+    w_align, seg_align = ctcalign.align(speech,normalise_transcript(transcript),lang_aligner)
 
 
     # set up the graph shape
@@ -113,5 +112,3 @@ def align_and_graph(wav_path, transcript):
 
 
 # uppboðssøla bussleiðini viðmerkingar upprunaligur
-
-
 
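
graph.py no longer relies on ctcalign's old module-level model: align_and_graph takes the aligner as an explicit lang_aligner argument, so the caller chooses which language or model to use. A hypothetical caller, assuming ctcalign.py and graph.py are importable together and using a placeholder wav path and transcript:

    import ctcalign, graph

    aligner = ctcalign.CTCAligner(
        "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
        '|', '[PAD]',
    )
    # aligns the transcript to the recording and builds the graph;
    # the plotting code itself is outside this diff
    graph.align_and_graph("recording.wav", "halló heimur", aligner)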