Rewrite to handle larger text
Browse files- handler.py +81 -11
handler.py
CHANGED
@@ -4,9 +4,80 @@ import torch
|
|
4 |
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
|
5 |
from datasets import load_dataset
|
6 |
import time
|
|
|
7 |
from typing import Dict, List, Any
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, path=""):
|
@@ -37,27 +108,26 @@ class EndpointHandler:
|
|
37 |
|
38 |
start_time = time.time()
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
|
46 |
-
speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
|
47 |
|
48 |
-
|
49 |
|
50 |
run_time_total = time.time() - start_time
|
51 |
-
|
52 |
|
53 |
# Return the expected response format
|
54 |
return {
|
55 |
"statusCode": 200,
|
56 |
"body": {
|
57 |
-
"audio":
|
58 |
"sampling_rate": 16000,
|
59 |
-
"run_time_processor": str(run_time_processor),
|
60 |
-
"run_time_speech": str(run_time_speech),
|
61 |
"run_time_total": str(run_time_total),
|
62 |
}
|
63 |
}
|
|
|
4 |
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
|
5 |
from datasets import load_dataset
|
6 |
import time
|
7 |
+
import re
|
8 |
from typing import Dict, List, Any
|
9 |
|
10 |
+
#from tourtise utils
|
11 |
+
def split_and_recombine_text(text, desired_length=200, max_length=300):
|
12 |
+
"""Split text it into chunks of a desired length trying to keep sentences intact."""
|
13 |
+
# normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
|
14 |
+
text = re.sub(r'\n\n+', '\n', text)
|
15 |
+
text = re.sub(r'\s+', ' ', text)
|
16 |
+
text = re.sub(r'[“”]', '"', text)
|
17 |
+
|
18 |
+
rv = []
|
19 |
+
in_quote = False
|
20 |
+
current = ""
|
21 |
+
split_pos = []
|
22 |
+
pos = -1
|
23 |
+
end_pos = len(text) - 1
|
24 |
+
|
25 |
+
def seek(delta):
|
26 |
+
nonlocal pos, in_quote, current
|
27 |
+
is_neg = delta < 0
|
28 |
+
for _ in range(abs(delta)):
|
29 |
+
if is_neg:
|
30 |
+
pos -= 1
|
31 |
+
current = current[:-1]
|
32 |
+
else:
|
33 |
+
pos += 1
|
34 |
+
current += text[pos]
|
35 |
+
if text[pos] == '"':
|
36 |
+
in_quote = not in_quote
|
37 |
+
return text[pos]
|
38 |
+
|
39 |
+
def peek(delta):
|
40 |
+
p = pos + delta
|
41 |
+
return text[p] if p < end_pos and p >= 0 else ""
|
42 |
+
|
43 |
+
def commit():
|
44 |
+
nonlocal rv, current, split_pos
|
45 |
+
rv.append(current)
|
46 |
+
current = ""
|
47 |
+
split_pos = []
|
48 |
+
|
49 |
+
while pos < end_pos:
|
50 |
+
c = seek(1)
|
51 |
+
# do we need to force a split?
|
52 |
+
if len(current) >= max_length:
|
53 |
+
if len(split_pos) > 0 and len(current) > (desired_length / 2):
|
54 |
+
# we have at least one sentence and we are over half the desired length, seek back to the last split
|
55 |
+
d = pos - split_pos[-1]
|
56 |
+
seek(-d)
|
57 |
+
else:
|
58 |
+
# no full sentences, seek back until we are not in the middle of a word and split there
|
59 |
+
while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
|
60 |
+
c = seek(-1)
|
61 |
+
commit()
|
62 |
+
# check for sentence boundaries
|
63 |
+
elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
|
64 |
+
# seek forward if we have consecutive boundary markers but still within the max length
|
65 |
+
while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
|
66 |
+
c = seek(1)
|
67 |
+
split_pos.append(pos)
|
68 |
+
if len(current) >= desired_length:
|
69 |
+
commit()
|
70 |
+
# treat end of quote as a boundary if its followed by a space or newline
|
71 |
+
elif in_quote and peek(1) == '"' and peek(2) in '\n ':
|
72 |
+
seek(2)
|
73 |
+
split_pos.append(pos)
|
74 |
+
rv.append(current)
|
75 |
+
|
76 |
+
# clean up, remove lines with only whitespace or punctuation
|
77 |
+
rv = [s.strip() for s in rv]
|
78 |
+
rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
|
79 |
+
|
80 |
+
return rv
|
81 |
|
82 |
class EndpointHandler:
|
83 |
def __init__(self, path=""):
|
|
|
108 |
|
109 |
start_time = time.time()
|
110 |
|
111 |
+
texts = split_and_recombine_text(given_text)
|
112 |
+
audios = []
|
113 |
+
|
114 |
+
for t in progress.tqdm(texts):
|
115 |
+
inputs = self.processor(text=t, return_tensors="pt")
|
116 |
+
speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
|
117 |
|
118 |
+
audios.append(speech.numpy())
|
119 |
|
|
|
120 |
|
121 |
+
final_speech = np.concatenate(audios)
|
122 |
|
123 |
run_time_total = time.time() - start_time
|
|
|
124 |
|
125 |
# Return the expected response format
|
126 |
return {
|
127 |
"statusCode": 200,
|
128 |
"body": {
|
129 |
+
"audio": final_speech, # Consider encoding this to a suitable format
|
130 |
"sampling_rate": 16000,
|
|
|
|
|
131 |
"run_time_total": str(run_time_total),
|
132 |
}
|
133 |
}
|