import librosa
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import time
import re
import inflect
from typing import Dict, List, Any

def contains_special_characters(s):
    return bool(re.search(r'[π“΅π–Ύπ“žπšŸπ”Ÿ]', s))

def check_punctuation(s):
    # Return the trailing sentence punctuation of s, or '' if there is none
    return s[-1] if s and s[-1] in '.,!?' else ''

def convert_numbers_to_text(input_string):
    p = inflect.engine()
    new_string = input_string

    # Find patterns like [6/7] or other number-character combinations
    mixed_patterns = re.findall(r'\[?\b\d+[^)\] ]*\]?', new_string)
    for pattern in mixed_patterns:
        # Skip bare numbers (with optional trailing punctuation); the word-level
        # loop below handles those, including the special reading for years.
        if re.fullmatch(r'\d+[.,!?]*', pattern):
            continue
        # Isolate numbers from other characters
        numbers = re.findall(r'\d+', pattern)
        # Replace each number with its word form within the pattern
        pattern_with_words = pattern
        for number in numbers:
            number_word = p.number_to_words(number)
            pattern_with_words = re.sub(number, number_word, pattern_with_words, count=1)
        new_string = new_string.replace(pattern, pattern_with_words)

    words = new_string.split()
    new_words = []
    
    for word in words:

        # Strip any trailing punctuation so the digit checks below see the bare
        # token; it is re-attached when the word is appended.
        punct = check_punctuation(word)
        if punct:
            word = word[:-1]

        if contains_special_characters(word):
            pass
        elif word.isdigit() and len(word) == 4:  # Check for years (4-digit numbers)
            year = int(word)
            if year < 2000:
                # Split the year into two parts, e.g. 1984 -> "nineteen eighty-four"
                first_part = year // 100
                second_part = year % 100
                # Convert each part to words and combine
                word = p.number_to_words(first_part) + " " + p.number_to_words(second_part)
            else:
                # Convert directly for year 2000 and beyond
                word = p.number_to_words(year)
        elif word.replace(',', '').isdigit():  # Check for any other digits
            number = int(word.replace(',', ''))
            word = p.number_to_words(number).replace(',', '')

        new_words.append(word + punct)

    return ' '.join(new_words)
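
# Rough illustration of convert_numbers_to_text (the exact wording of longer
# numbers depends on the inflect version, so treat the output as approximate):
#   convert_numbers_to_text("Born in 1984, she ran 42 km.")
#   -> "Born in nineteen eighty-four, she ran forty-two km."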

def split_and_recombine_text(text, desired_length=200, max_length=400):
    """Split text it into chunks of a desired length trying to keep sentences intact."""
    # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[β€œβ€]', '"', text)

    rv = []
    in_quote = False
    current = ""
    split_pos = []
    pos = -1
    end_pos = len(text) - 1

    def seek(delta):
        nonlocal pos, in_quote, current
        is_neg = delta < 0
        for _ in range(abs(delta)):
            if is_neg:
                pos -= 1
                current = current[:-1]
            else:
                pos += 1
                current += text[pos]
            if text[pos] == '"':
                in_quote = not in_quote
        return text[pos]

    def peek(delta):
        p = pos + delta
        return text[p] if 0 <= p <= end_pos else ""

    def commit():
        nonlocal rv, current, split_pos
        rv.append(current)
        current = ""
        split_pos = []

    while pos < end_pos:
        c = seek(1)
        # do we need to force a split?
        if len(current) >= max_length:
            if len(split_pos) > 0 and len(current) > (desired_length / 2):
                # we have at least one sentence and we are over half the desired length, seek back to the last split
                d = pos - split_pos[-1]
                seek(-d)
            else:
                # no full sentences, seek back until we are not in the middle of a word and split there
                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
                    c = seek(-1)
            commit()
        # check for sentence boundaries
        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
            # seek forward if we have consecutive boundary markers but still within the max length
            while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
                c = seek(1)
            split_pos.append(pos)
            if len(current) >= desired_length:
                commit()
        # treat end of quote as a boundary if it's followed by a space or newline
        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
            seek(2)
            split_pos.append(pos)
    rv.append(current)

    # clean up, remove lines with only whitespace or punctuation
    rv = [s.strip() for s in rv]
    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]

    return rv
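
# Usage sketch (assuming the defaults above): each returned chunk runs to the
# first sentence boundary at or past desired_length characters, so long input
# comes back as sentence-aligned pieces the model can synthesize one at a time.
#   chunks = split_and_recombine_text(long_article_text)  # long_article_text: any str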

class EndpointHandler:
    def __init__(self, path=""):

        #checkpoint = "microsoft/speecht5_tts"
        #vocoder_id = "microsoft/speecht5_hifigan"
        #dataset_id = "Matthijs/cmu-arctic-xvectors"
        
        checkpoint = "Dupaja/speecht5_tts"
        vocoder_id = "Dupaja/speecht5_hifigan"
        dataset_id = "Dupaja/cmu-arctic-xvectors"

        self.model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint, low_cpu_mem_usage=True)
        
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)

        embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
        self.embeddings_dataset = embeddings_dataset

        self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        
        given_text = data.get("inputs", "")
        given_text = given_text.replace('&','and')
        given_text = given_text.replace('-',' ')

        start_time = time.time()
        
        given_text = convert_numbers_to_text(given_text)

        texts = split_and_recombine_text(given_text)
        audios = []
    
        for t in texts:
            inputs = self.processor(text=t, return_tensors="pt")
            speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)

            audios.append(speech)
            #audios.append(speech.numpy())


        final_speech = np.concatenate(audios)

        run_time_total = time.time() - start_time

        # Return the expected response format
        return {
            "statusCode": 200,
            "body": {
                "audio": final_speech,  # Consider encoding this to a suitable format
                "sampling_rate": 16000,
                "run_time_total": str(run_time_total),
            }
        }

handler = EndpointHandler()
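
# Minimal local smoke test (a sketch only; when deployed, the serving toolkit
# imports this module and invokes `handler` with the request payload itself).
if __name__ == "__main__":
    payload = {"inputs": "The year 1984 had 366 days."}
    response = handler(payload)
    body = response["body"]
    print(f"{len(body['audio'])} samples at {body['sampling_rate']} Hz "
          f"(generated in {body['run_time_total']} s)")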