amedcj committed on
Commit
f70b4f1
·
verified ·
1 Parent(s): f91b0a2

Update app.py

Browse files

Updated app.py

Files changed (1) hide show
  1. app.py +170 -107
app.py CHANGED
@@ -1,107 +1,170 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
2
- import gradio as gr
3
- import numpy as np
4
- import scipy.io.wavfile
5
- import tempfile
6
- import scipy
7
-
8
- from transformers import VitsModel, AutoTokenizer
9
- import torch
10
- import scipy.io.wavfile
11
- import gradio as gr
12
- import tempfile
13
-
14
- # Load punctuation restoration pipeline
15
- punctuation_model_id = "oliverguhr/fullstop-punctuation-multilang-large"
16
- punct_tokenizer = AutoTokenizer.from_pretrained(punctuation_model_id)
17
- punct_model = AutoModelForTokenClassification.from_pretrained(punctuation_model_id)
18
- punct_pipe = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer, aggregation_strategy="simple")
19
-
20
-
21
-
22
- # Load the model and tokenizer once
23
- model = VitsModel.from_pretrained("facebook/mms-tts-kmr-script_latin")
24
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kmr-script_latin")
25
-
26
- # Simple number-to-Kurmanji-word mapping
27
- num2word = {
28
- "0": "sifir",
29
- "1": "yek",
30
- "2": "du",
31
- "3": "sê",
32
- "4": "çar",
33
- "5": "pênc",
34
- "6": "şeş",
35
- "7": "heft",
36
- "8": "heşt",
37
- "9": "neh",
38
- "10": "deh",
39
- "11": "yanzdeh",
40
- "12": "dûwanzdeh",
41
- "13": "sêzdeh",
42
- "14": "çardeh",
43
- "15": "panzdeh",
44
- "16": "şanzdeh",
45
- "17": "hevzdeh",
46
- "18": "hejdeh",
47
- "19": "nozdeh",
48
- "20": "bîst",
49
- "30": "",
50
- "40": "çil",
51
- "50": "pêncî",
52
- "60": "şêst",
53
- "70": "heftê",
54
- "80": "heştê",
55
- "90": "nod",
56
- "100": "sed",
57
- # You can expand this...
58
- }
59
-
60
- import re
61
-
62
- def replace_numbers_with_words(text):
63
- def repl(match):
64
- num = match.group()
65
- return num2word.get(num, num) # fallback to number if unknown
66
- return re.sub(r'\b\d+\b', repl, text)
67
-
68
-
69
- def text_to_speech(text):
70
- # Convert text to input format
71
- text = restore_punctuation(text)
72
- text = replace_numbers_with_words(text)
73
- inputs = tokenizer(text, return_tensors="pt")
74
- with torch.no_grad():
75
- output = model(**inputs).waveform
76
-
77
- # Save the waveform to a temporary .wav file
78
- tmp_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
79
- scipy.io.wavfile.write(tmp_path, rate=model.config.sampling_rate, data=output.squeeze().numpy())
80
-
81
- return tmp_path
82
-
83
- def restore_punctuation(text):
84
- results = punct_pipe(text)
85
- punctuated = ""
86
- for token in results:
87
- word = token['word']
88
- punct = token.get('entity_group', '')
89
- # Simple heuristic to add punctuation after words when predicted
90
- if punct == "PERIOD":
91
- punctuated += word + ". "
92
- elif punct == "COMMA":
93
- punctuated += word + ", "
94
- else:
95
- punctuated += word + " "
96
- return punctuated.strip()
97
-
98
- # Gradio UI
99
- interface = gr.Interface(
100
- fn=text_to_speech,
101
- inputs=gr.Textbox(label="Enter Kurmanji Text"),
102
- outputs=gr.Audio(label="Generated Speech"),
103
- title="Kurmanji Text-to-Speech",
104
- description="Type Kurmanji Kurdish (Latin script) text and hear it spoken."
105
- )
106
-
107
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
2
+ import gradio as gr
3
+ import numpy as np
4
+ import scipy.io.wavfile
5
+ import tempfile
6
+ import os
7
+ from transformers import VitsModel, AutoTokenizer
8
+ import torch
9
+ import re
10
+
11
# Punctuation-restoration pipeline. Loading is best-effort: on any failure the
# error is printed and punct_pipe is left as None.
try:
    punctuation_model_id = "oliverguhr/fullstop-punctuation-multilang-large"
    punct_tokenizer = AutoTokenizer.from_pretrained(punctuation_model_id)
    punct_model = AutoModelForTokenClassification.from_pretrained(punctuation_model_id)
    punct_pipe = pipeline(
        "token-classification",
        model=punct_model,
        tokenizer=punct_tokenizer,
        aggregation_strategy="simple",
    )
except Exception as e:
    print(f"Error loading punctuation model: {e}")
    punct_pipe = None
else:
    print("Punctuation model loaded successfully")

# Kurmanji (Latin script) TTS model and its tokenizer. If either fails to
# load, both names are reset to None together.
try:
    model = VitsModel.from_pretrained("facebook/mms-tts-kmr-script_latin")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kmr-script_latin")
except Exception as e:
    print(f"Error loading TTS model: {e}")
    model = None
    tokenizer = None
else:
    print("TTS model loaded successfully")
31
+
32
# Kurmanji words for the numbers that have a name of their own. Every other
# value from 0 to 999 is composed from these entries by _number_to_kurmanji().
num2word = {
    "0": "sifir",
    "1": "yek",
    "2": "du",
    "3": "sê",
    "4": "çar",
    "5": "pênc",
    "6": "şeş",
    "7": "heft",
    "8": "heşt",
    "9": "neh",
    "10": "deh",
    "11": "yanzdeh",
    "12": "dûwanzdeh",
    "13": "sêzdeh",
    "14": "çardeh",
    "15": "panzdeh",
    "16": "şanzdeh",
    "17": "hevzdeh",
    "18": "hejdeh",
    "19": "nozdeh",
    "20": "bîst",
    "30": "sî",
    "40": "çil",
    "50": "pêncî",
    "60": "şêst",
    "70": "heftê",
    "80": "heştê",
    "90": "nod",
    "100": "sed",
}


def _number_to_kurmanji(value):
    """Spell out an integer in the range 0..999 using the num2word table.

    Compound numbers are joined with the conjunction "û", e.g. 21 becomes
    "bîst û yek" and 123 becomes "sed û bîst û sê".
    """
    key = str(value)
    if key in num2word:
        return num2word[key]
    if value < 100:
        # 21..99 that are not round tens: "<tens> û <units>".
        tens, units = divmod(value, 10)
        return f"{num2word[str(tens * 10)]} û {num2word[str(units)]}"
    hundreds, rest = divmod(value, 100)
    if hundreds == 1:
        head = num2word["100"]
    else:
        head = f"{num2word[str(hundreds)]} {num2word['100']}"
    if rest == 0:
        return head
    return f"{head} û {_number_to_kurmanji(rest)}"


def replace_numbers_with_words(text):
    """Replace standalone digit runs in *text* with Kurmanji number words.

    Numbers from 0 to 999 are spelled out. Anything else (four or more
    digits, leading zeros such as "007", non-ASCII digits) is left unchanged
    so that no part of the input is silently dropped.
    """
    def repl(match):
        digits = match.group()
        word = num2word.get(digits)
        if word:
            return word
        # Only compose plain ASCII numbers without leading zeros; int() would
        # otherwise quietly normalise "007" or non-ASCII digits.
        if len(digits) <= 3 and digits.isascii() and not digits.startswith("0"):
            return _number_to_kurmanji(int(digits))
        return digits  # fallback to the number itself if it cannot be spelled

    return re.sub(r'\b\d+\b', repl, text)
70
+
71
# Maps the labels a punctuation model may emit to the mark to append. Both the
# symbolic names the code originally checked for and the literal punctuation
# characters are accepted. Unknown labels (including "0" and "-") add nothing.
_PUNCT_LABEL_TO_MARK = {
    "PERIOD": ".",
    "COMMA": ",",
    ".": ".",
    ",": ",",
    "?": "?",
    ":": ":",
}

# A word already ending in one of these keeps its own punctuation.
_TRAILING_PUNCTUATION = ".,?!:;"


def restore_punctuation(text):
    """Return *text* with punctuation predicted by ``punct_pipe`` inserted.

    Each group returned by the pipeline is emitted as its ``word`` followed by
    the mark for its ``entity_group`` label; groups are joined with single
    spaces. The original text is returned unchanged when the pipeline is not
    loaded, produces no usable groups, or raises.
    """
    if punct_pipe is None:
        return text  # Return original text if punctuation model failed to load

    try:
        pieces = []
        for group in punct_pipe(text):
            word = group['word'].strip()
            if not word:
                continue
            mark = _PUNCT_LABEL_TO_MARK.get(group.get('entity_group', ''), "")
            # Avoid output such as "Silav!." when the input was already punctuated.
            if mark and word[-1] in _TRAILING_PUNCTUATION:
                mark = ""
            pieces.append(word + mark)
        # An empty result would erase the user's input; keep the original instead.
        return " ".join(pieces) if pieces else text
    except Exception as e:
        print(f"Error in punctuation restoration: {e}")
        return text  # Return original text if punctuation fails
92
+
93
def text_to_speech(text):
    """Synthesize *text* and return the path of the generated ``.wav`` file.

    The text is first passed through restore_punctuation() and
    replace_numbers_with_words(), then tokenized and fed to the TTS model.

    Returns:
        Path of a temporary WAV file. It is created with ``delete=False`` and
        is NOT removed by this function, because the caller still has to read
        it after the function returns.

    Raises:
        gr.Error: if the TTS model is not loaded, the input is empty, or
            synthesis fails. Raising keeps the return type a single value
            (matching the single Audio output) and lets Gradio show the
            message to the user instead of silently returning nothing.
    """
    # Check if models are loaded
    if model is None or tokenizer is None:
        raise gr.Error("Error: TTS model failed to load")

    if not text or text.strip() == "":
        raise gr.Error("Please enter some text")

    print(f"Processing text: {text}")

    tmp_path = None
    try:
        # Normalize the text before tokenizing.
        processed_text = restore_punctuation(text)
        processed_text = replace_numbers_with_words(processed_text)
        print(f"Processed text: {processed_text}")

        inputs = tokenizer(processed_text, return_tensors="pt")

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = model(**inputs).waveform

        waveform = output.squeeze().numpy()

        # Reserve a file name; the handle is closed right away so the WAV
        # writer can open the path itself.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        scipy.io.wavfile.write(
            tmp_path,
            rate=model.config.sampling_rate,
            data=waveform,
        )
    except Exception as e:
        # Do not leave a half-written temp file behind on failure.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
        error_msg = f"Error in text_to_speech: {str(e)}"
        print(error_msg)
        raise gr.Error(error_msg) from e

    print(f"Audio saved to: {tmp_path}")
    return tmp_path
138
+
139
+ # Create Gradio interface with better error handling
140
def create_interface():
    """Build and return the Gradio interface wrapping text_to_speech.

    The UI has one multi-line text input, one audio output and three example
    sentences; example outputs are not cached.
    """
    text_input = gr.Textbox(
        label="Enter Kurmanji Text",
        placeholder="Type your Kurmanji Kurdish text here...",
        lines=3,
    )
    audio_output = gr.Audio(label="Generated Speech")
    sample_sentences = [
        ["Silav! Ez bi xêr im."],
        ["Tu çawa yî?"],
        ["Navê min Kurdî ye."],
    ]
    return gr.Interface(
        fn=text_to_speech,
        inputs=text_input,
        outputs=audio_output,
        title="Kurmanji Text-to-Speech",
        description="Type Kurmanji Kurdish (Latin script) text and hear it spoken.",
        examples=sample_sentences,
        cache_examples=False,
    )
159
+
160
if __name__ == "__main__":
    # Informational only: report when the SPACE_ID environment variable is set.
    if os.environ.get("SPACE_ID") is not None:
        print("Running on Hugging Face Spaces")

    # Listen on all network interfaces at port 7860, without a share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    interface = create_interface()
    interface.launch(**launch_options)