rbcurzon committed on
Commit
739d409
·
verified ·
1 Parent(s): c6e5581

refactor: Allow synthesize() to use GPU if available

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -165,27 +165,25 @@ async def translate_text(text: str,
165
  "tgtLang": tgtLang
166
  }
167
  return result_dict
168
-
169
  @app.post("/synthesize/")
170
  async def synthesize(text: str):
171
  model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
172
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
173
-
 
 
 
174
  inputs = tokenizer(text, return_tensors="pt")
175
-
 
 
176
  with torch.no_grad():
177
- output = model(**inputs).waveform
178
-
179
- data_np = output.numpy()
180
- data_np_squeezed = np.squeeze(data_np)
181
-
182
  temp_file = create_temp_filename()
183
 
184
- scipy.io.wavfile.write(
185
- temp_file,
186
- rate=model.config.sampling_rate,
187
- data=data_np_squeezed
188
- )
189
  logging.info(f"Synthesizing completed for text: {text}")
190
 
191
  return FileResponse(
 
165
  "tgtLang": tgtLang
166
  }
167
  return result_dict
 
168
  @app.post("/synthesize/")
169
  async def synthesize(text: str):
170
  model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
171
+ tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-tgl")
172
+
173
+ device = "cuda" if torch.cuda.is_available() else "cpu"
174
+ model.to(device)
175
+
176
  inputs = tokenizer(text, return_tensors="pt")
177
+ input_ids = inputs["input_ids"].to(device)
178
+
179
+
180
  with torch.no_grad():
181
+ outputs = model(input_ids)
182
+
183
+ speech = outputs["waveform"]
 
 
184
  temp_file = create_temp_filename()
185
 
186
+ torchaudio.save(temp_file, speech.cpu(), 16000)
 
 
 
 
187
  logging.info(f"Synthesizing completed for text: {text}")
188
 
189
  return FileResponse(