Dupaja committed on
Commit c5c6476 · 1 Parent(s): fc3317e

Rewrite to handle larger text

Files changed (1)
  1. handler.py +81 -11
handler.py CHANGED
@@ -4,9 +4,80 @@ import torch
 from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
 from datasets import load_dataset
 import time
+import re
 from typing import Dict, List, Any
 
-
+# from tortoise utils
+def split_and_recombine_text(text, desired_length=200, max_length=300):
+    """Split text into chunks of a desired length, trying to keep sentences intact."""
+    # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
+    text = re.sub(r'\n\n+', '\n', text)
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'[“”]', '"', text)
+
+    rv = []
+    in_quote = False
+    current = ""
+    split_pos = []
+    pos = -1
+    end_pos = len(text) - 1
+
+    def seek(delta):
+        nonlocal pos, in_quote, current
+        is_neg = delta < 0
+        for _ in range(abs(delta)):
+            if is_neg:
+                pos -= 1
+                current = current[:-1]
+            else:
+                pos += 1
+                current += text[pos]
+            if text[pos] == '"':
+                in_quote = not in_quote
+        return text[pos]
+
+    def peek(delta):
+        p = pos + delta
+        return text[p] if p < end_pos and p >= 0 else ""
+
+    def commit():
+        nonlocal rv, current, split_pos
+        rv.append(current)
+        current = ""
+        split_pos = []
+
+    while pos < end_pos:
+        c = seek(1)
+        # do we need to force a split?
+        if len(current) >= max_length:
+            if len(split_pos) > 0 and len(current) > (desired_length / 2):
+                # we have at least one sentence and we are over half the desired length, seek back to the last split
+                d = pos - split_pos[-1]
+                seek(-d)
+            else:
+                # no full sentences, seek back until we are not in the middle of a word and split there
+                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
+                    c = seek(-1)
+            commit()
+        # check for sentence boundaries
+        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
+            # seek forward if we have consecutive boundary markers but still within the max length
+            while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
+                c = seek(1)
+            split_pos.append(pos)
+            if len(current) >= desired_length:
+                commit()
+        # treat end of quote as a boundary if it's followed by a space or newline
+        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
+            seek(2)
+            split_pos.append(pos)
+    rv.append(current)
+
+    # clean up, remove lines with only whitespace or punctuation
+    rv = [s.strip() for s in rv]
+    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
+
+    return rv
 
 class EndpointHandler:
     def __init__(self, path=""):
 
@@ -37,27 +108,26 @@ class EndpointHandler:
 
         start_time = time.time()
 
-        inputs = self.processor(text=given_text, return_tensors="pt")
-
-        run_time_processor = time.time() - start_time
-
-        start_time_speech = time.time()
-
-        speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
-
-        run_time_speech = time.time() - start_time_speech
+        texts = split_and_recombine_text(given_text)  # chunk long input so each piece fits the model
+        audios = []
+
+        for t in texts:
+            inputs = self.processor(text=t, return_tensors="pt")
+            speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
+
+            audios.append(speech.numpy())
+
+        final_speech = np.concatenate(audios)
 
         run_time_total = time.time() - start_time
-
 
         # Return the expected response format
         return {
             "statusCode": 200,
             "body": {
-                "audio": speech.numpy(), # Consider encoding this to a suitable format
+                "audio": final_speech, # Consider encoding this to a suitable format
                 "sampling_rate": 16000,
-                "run_time_processor": str(run_time_processor),
-                "run_time_speech": str(run_time_speech),
                 "run_time_total": str(run_time_total),
             }
         }
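
A quick way to see how the new helper behaves is the sketch below. It is illustrative only: it assumes handler.py from this commit is importable (which pulls in torch, transformers and datasets at module level), and the sample passage is made up.

from handler import split_and_recombine_text

# illustrative long input; any long string works
long_text = "SpeechT5 is asked to read a fairly long passage about text to speech. " * 20

chunks = split_and_recombine_text(long_text, desired_length=200, max_length=300)

# each chunk should land near desired_length and end on a sentence boundary where possible
for i, chunk in enumerate(chunks):
    print(i, len(chunk), chunk[:40] + "...")

Each chunk can then be passed to the processor and generate_speech on its own, and the per-chunk waveforms concatenated into a single 16 kHz clip, which is what the rewritten handler returns.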