amedcj committed on
Commit
97b7215
·
verified ·
1 Parent(s): 17058e1

Update app.py

Browse files

Updated app.py

Files changed (1) hide show
  1. app.py +187 -187
app.py CHANGED
@@ -1,188 +1,188 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
2
- import gradio as gr
3
- import numpy as np
4
- import scipy.io.wavfile
5
- import tempfile
6
- import os
7
- from transformers import VitsModel, AutoTokenizer
8
- import torch
9
- import re
10
- import traceback
11
-
12
- print("Starting application...")
13
-
14
- # Global variables for models
15
- punct_pipe = None
16
- model = None
17
- tokenizer = None
18
-
19
- def load_models():
20
- global punct_pipe, model, tokenizer
21
-
22
- print("Loading punctuation model...")
23
- try:
24
- punctuation_model_id = "oliverguhr/fullstop-punctuation-multilang-large"
25
- punct_tokenizer = AutoTokenizer.from_pretrained(punctuation_model_id)
26
- punct_model = AutoModelForTokenClassification.from_pretrained(punctuation_model_id)
27
- punct_pipe = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer, aggregation_strategy="simple")
28
- print("✓ Punctuation model loaded successfully")
29
- except Exception as e:
30
- print(f"✗ Error loading punctuation model: {e}")
31
- punct_pipe = None
32
-
33
- print("Loading TTS model...")
34
- try:
35
- model = VitsModel.from_pretrained("facebook/mms-tts-kmr-script_latin")
36
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kmr-script_latin")
37
- print("✓ TTS model loaded successfully")
38
- except Exception as e:
39
- print(f"✗ Error loading TTS model: {e}")
40
- model = None
41
- tokenizer = None
42
-
43
- # Load models at startup
44
- load_models()
45
-
46
- # Simple number-to-Kurmanji-word mapping
47
- num2word = {
48
- "0": "sifir", "1": "yek", "2": "du", "3": "sê", "4": "çar", "5": "pênc",
49
- "6": "şeş", "7": "heft", "8": "heşt", "9": "neh", "10": "deh"
50
- }
51
-
52
- def replace_numbers_with_words(text):
53
- def repl(match):
54
- num = match.group()
55
- return num2word.get(num, num)
56
- return re.sub(r'\b\d+\b', repl, text)
57
-
58
- def restore_punctuation(text):
59
- if punct_pipe is None:
60
- print("Punctuation model not available, skipping...")
61
- return text
62
-
63
- try:
64
- results = punct_pipe(text)
65
- punctuated = ""
66
- for token in results:
67
- word = token['word']
68
- punct = token.get('entity_group', '')
69
- if punct == "PERIOD":
70
- punctuated += word + ". "
71
- elif punct == "COMMA":
72
- punctuated += word + ", "
73
- else:
74
- punctuated += word + " "
75
- return punctuated.strip()
76
- except Exception as e:
77
- print(f"Punctuation error: {e}")
78
- return text
79
-
80
- def text_to_speech(text):
81
- print(f"=== TTS Function Called ===")
82
- print(f"Input text: '{text}'")
83
-
84
- try:
85
- # Basic validation
86
- if not text or text.strip() == "":
87
- error_msg = "Please enter some text"
88
- print(f"Error: {error_msg}")
89
- return None
90
-
91
- # Check if models are loaded
92
- if model is None or tokenizer is None:
93
- error_msg = "TTS model not loaded properly"
94
- print(f"Error: {error_msg}")
95
- return None
96
-
97
- print("Processing text...")
98
-
99
- # Process text
100
- processed_text = text.strip() # Start simple, skip punctuation for now
101
- processed_text = replace_numbers_with_words(processed_text)
102
- print(f"Processed text: '{processed_text}'")
103
-
104
- # Tokenize
105
- print("Tokenizing...")
106
- inputs = tokenizer(processed_text, return_tensors="pt")
107
- print(f"Tokenized successfully, input_ids shape: {inputs['input_ids'].shape}")
108
-
109
- # Generate audio
110
- print("Generating audio...")
111
- with torch.no_grad():
112
- output = model(**inputs).waveform
113
- print(f"Audio generated, shape: {output.shape}")
114
-
115
- # Convert to numpy
116
- waveform = output.squeeze().numpy()
117
- print(f"Waveform shape: {waveform.shape}")
118
-
119
- # Save to file
120
- print("Saving audio file...")
121
- tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
122
- tmp_path = tmp_file.name
123
- tmp_file.close()
124
-
125
- scipy.io.wavfile.write(
126
- tmp_path,
127
- rate=model.config.sampling_rate,
128
- data=waveform
129
- )
130
-
131
- print(f"✓ Audio saved to: {tmp_path}")
132
- print("=== TTS Function Completed Successfully ===")
133
- return tmp_path
134
-
135
- except Exception as e:
136
- error_msg = f"Error in TTS: {str(e)}"
137
- print(f"✗ {error_msg}")
138
- print("Full traceback:")
139
- traceback.print_exc()
140
- return None
141
-
142
- # Simple test function to verify Gradio is working
143
- def test_function(text):
144
- print(f"Test function called with: {text}")
145
- return f"You entered: {text}"
146
-
147
- # Create a simple interface first to test
148
- print("Creating Gradio interface...")
149
-
150
- # Option 1: Simple test interface (uncomment to test basic functionality)
151
- # interface = gr.Interface(
152
- # fn=test_function,
153
- # inputs=gr.Textbox(label="Test Input"),
154
- # outputs=gr.Textbox(label="Test Output"),
155
- # title="Test Interface"
156
- # )
157
-
158
- # Option 2: Full TTS interface
159
- interface = gr.Interface(
160
- fn=text_to_speech,
161
- inputs=gr.Textbox(
162
- label="Enter Kurmanji Text",
163
- placeholder="e.g. Silav! Ez bi xêr im.",
164
- lines=2,
165
- value="" # Default empty value
166
- ),
167
- outputs=gr.Audio(label="Generated Speech"),
168
- title="Kurmanji Text-to-Speech",
169
- description="Enter Kurmanji Kurdish text to convert to speech.",
170
- examples=[
171
- ["Silav"],
172
- ["Ez bi xêr im"],
173
- ["Spas"]
174
- ],
175
- cache_examples=False,
176
- flagging_mode="never"
177
- )
178
-
179
- print("Launching interface...")
180
-
181
- if __name__ == "__main__":
182
- interface.launch(
183
- debug=True,
184
- share=False,
185
- show_error=True,
186
- server_name="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
187
- server_port=7860
188
  )
 
1
+ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
2
+ import gradio as gr
3
+ import numpy as np
4
+ import scipy.io.wavfile
5
+ import tempfile
6
+ import os
7
+ from transformers import VitsModel, AutoTokenizer
8
+ import torch
9
+ import re
10
+ import traceback
11
+
12
+ print("Starting application...")
13
+
14
# Shared model handles. Each one stays None until load_models() has loaded it
# successfully, so callers can test for None to see whether it is available.
punct_pipe = None
model = None
tokenizer = None


def load_models():
    """Load the punctuation pipeline and the TTS model into module globals.

    The two loads are attempted independently, so a failure in one does not
    prevent the other. Any exception is caught and printed, and the handles
    of the failed load are left as None:

    * ``punct_pipe`` -- token-classification pipeline for punctuation.
    * ``model`` / ``tokenizer`` -- VITS model and its tokenizer; both are
      reset to None if either of them fails to load.
    """
    global punct_pipe, model, tokenizer

    print("Loading punctuation model...")
    punct_repo = "oliverguhr/fullstop-punctuation-multilang-large"
    try:
        # Tokenizer first, then the classifier, then the pipeline around both.
        repo_tokenizer = AutoTokenizer.from_pretrained(punct_repo)
        repo_classifier = AutoModelForTokenClassification.from_pretrained(punct_repo)
        punct_pipe = pipeline(
            "token-classification",
            model=repo_classifier,
            tokenizer=repo_tokenizer,
            aggregation_strategy="simple",
        )
        print("✓ Punctuation model loaded successfully")
    except Exception as load_error:
        print(f"✗ Error loading punctuation model: {load_error}")
        punct_pipe = None

    print("Loading TTS model...")
    tts_repo = "facebook/mms-tts-kmr-script_latin"
    try:
        tts_model = VitsModel.from_pretrained(tts_repo)
        tts_tokenizer = AutoTokenizer.from_pretrained(tts_repo)
        # Publish the pair together so the globals are never half-populated.
        model, tokenizer = tts_model, tts_tokenizer
        print("✓ TTS model loaded successfully")
    except Exception as load_error:
        print(f"✗ Error loading TTS model: {load_error}")
        model = tokenizer = None
42
+
43
+ # Load models at startup
44
+ load_models()
45
+
46
# Kurmanji words for the integers 0 through 10, keyed by their digit strings.
num2word = {
    "0": "sifir",
    "1": "yek",
    "2": "du",
    "3": "sê",
    "4": "çar",
    "5": "pênc",
    "6": "şeş",
    "7": "heft",
    "8": "heşt",
    "9": "neh",
    "10": "deh",
}


def replace_numbers_with_words(text):
    """Return ``text`` with standalone numbers 0-10 spelled out in Kurmanji.

    Only digit runs delimited by word boundaries are considered, so digits
    glued to letters (e.g. "a1") are left alone. A digit run that has no
    entry in ``num2word`` (such as "11" or "007") is kept unchanged.
    """
    return re.sub(
        r'\b\d+\b',
        lambda found: num2word.get(found.group(), found.group()),
        text,
    )
57
+
58
def restore_punctuation(text):
    """Return ``text`` with punctuation re-inserted by the punctuation pipeline.

    Uses the module-level ``punct_pipe``. Every group the pipeline returns
    contributes its ``word``; when the group's ``entity_group`` label names a
    punctuation mark, that mark is appended directly after the word. Groups
    are joined with single spaces.

    Args:
        text: The unpunctuated input string.

    Returns:
        The punctuated string, or ``text`` unchanged when the pipeline is not
        loaded or raises while processing.
    """
    if punct_pipe is None:
        print("Punctuation model not available, skipping...")
        return text

    # Label -> mark to append. Per the fullstop model card the labels are the
    # raw symbols themselves ("0" meaning no punctuation), so matching only
    # "PERIOD"/"COMMA" never restored anything. The named labels are still
    # accepted so a pipeline that emits them keeps working. The "-" label is
    # deliberately not mapped: it is unclear from here how a dash should be
    # spaced, so such groups are emitted as plain words, as before.
    label_to_mark = {
        "PERIOD": ".",
        "COMMA": ",",
        ".": ".",
        ",": ",",
        "?": "?",
        ":": ":",
    }

    try:
        pieces = []
        for group in punct_pipe(text):
            mark = label_to_mark.get(group.get('entity_group', ''), "")
            pieces.append(group['word'] + mark)
        return " ".join(pieces).strip()
    except Exception as e:
        # Best effort: punctuation is optional, so fall back to the raw text.
        print(f"Punctuation error: {e}")
        return text
79
+
80
def text_to_speech(text):
    """Synthesize speech for ``text`` and return the path of the WAV file.

    Steps: validate the input, check that the module-level ``model`` and
    ``tokenizer`` are loaded, spell out numbers with
    ``replace_numbers_with_words``, tokenize, run the model without gradient
    tracking, and write the waveform to a temporary ``.wav`` file.
    Punctuation restoration is intentionally not applied here.

    Args:
        text: The text entered by the user; may be empty or None.

    Returns:
        The path of the generated WAV file, or None when the input is blank,
        the TTS model is not loaded, or any step fails. Progress and errors
        are printed to stdout. The returned file is not deleted by this
        function.
    """
    print("=== TTS Function Called ===")
    print(f"Input text: '{text}'")

    # Remembered outside the try body so a failed run can delete its own file.
    tmp_path = None
    try:
        # Basic validation (covers None, "" and whitespace-only input).
        if not text or not text.strip():
            print("Error: Please enter some text")
            return None

        # Check if models are loaded
        if model is None or tokenizer is None:
            print("Error: TTS model not loaded properly")
            return None

        print("Processing text...")

        # Process text
        processed_text = replace_numbers_with_words(text.strip())
        print(f"Processed text: '{processed_text}'")

        # Tokenize
        print("Tokenizing...")
        inputs = tokenizer(processed_text, return_tensors="pt")
        print(f"Tokenized successfully, input_ids shape: {inputs['input_ids'].shape}")

        # Generate audio
        print("Generating audio...")
        with torch.no_grad():
            output = model(**inputs).waveform
        print(f"Audio generated, shape: {output.shape}")

        # Convert to numpy
        waveform = output.squeeze().numpy()
        print(f"Waveform shape: {waveform.shape}")

        # Save to file. The handle is only used to reserve a unique name; it
        # is closed before scipy reopens the path for writing.
        print("Saving audio file...")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        scipy.io.wavfile.write(
            tmp_path,
            rate=model.config.sampling_rate,
            data=waveform,
        )

        print(f"✓ Audio saved to: {tmp_path}")
        print("=== TTS Function Completed Successfully ===")
        return tmp_path

    except Exception as e:
        print(f"✗ Error in TTS: {str(e)}")
        print("Full traceback:")
        traceback.print_exc()
        # The temp file was created with delete=False, so a failure after its
        # creation would otherwise leave an orphaned file behind on each error.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
141
+
142
+ # Simple test function to verify Gradio is working
143
+ def test_function(text):
144
+ print(f"Test function called with: {text}")
145
+ return f"You entered: {text}"
146
+
147
+ # Create a simple interface first to test
148
+ print("Creating Gradio interface...")
149
+
150
+ # Option 1: Simple test interface (uncomment to test basic functionality)
151
+ # interface = gr.Interface(
152
+ # fn=test_function,
153
+ # inputs=gr.Textbox(label="Test Input"),
154
+ # outputs=gr.Textbox(label="Test Output"),
155
+ # title="Test Interface"
156
+ # )
157
+
158
+ # Option 2: Full TTS interface
159
# Full TTS interface: one text box in, one audio player out, backed by
# text_to_speech.
interface = gr.Interface(
    title="Kurmanji Text-to-Speech",
    description="Enter Kurmanji Kurdish text to convert to speech.",
    fn=text_to_speech,
    inputs=gr.Textbox(
        value="",  # start with an empty box
        lines=2,
        label="Enter Kurmanji Text",
        placeholder="e.g. Silav! Ez bi xêr im.",
    ),
    outputs=gr.Audio(label="Generated Speech"),
    # One single-element row per example, matching the single text input.
    examples=[[phrase] for phrase in ("Silav", "Ez bi xêr im", "Spas")],
    cache_examples=False,
    flagging_mode="never",
)
178
+
179
+ print("Launching interface...")
180
+
181
if __name__ == "__main__":
    # Listen on loopback only, unless the SPACE_ID environment variable is
    # present, in which case bind to all interfaces.
    # NOTE(review): SPACE_ID is presumably set by the hosting Space - confirm.
    interface.launch(
        server_name="127.0.0.1" if "SPACE_ID" not in os.environ else "0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True,
    )