Dupaja commited on
Commit
bd9bb8e
·
1 Parent(s): d0c3581

Attempt to speed up

Browse files
Files changed (1) hide show
  1. handler.py +8 -6
handler.py CHANGED
@@ -18,12 +18,14 @@ class EndpointHandler:
18
  checkpoint = "Dupaja/speecht5_tts"
19
  vocoder_id = "Dupaja/speecht5_hifigan"
20
  dataset_id = "Dupaja/cmu-arctic-xvectors"
21
-
22
- self.model= SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
23
- self.processor = SpeechT5Processor.from_pretrained(checkpoint)
24
- self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)
25
- embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
26
- self.embeddings_dataset = embeddings_dataset
 
 
27
  self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
28
 
29
 
 
18
  checkpoint = "Dupaja/speecht5_tts"
19
  vocoder_id = "Dupaja/speecht5_hifigan"
20
  dataset_id = "Dupaja/cmu-arctic-xvectors"
21
+
22
+ with torch.device("cuda"):
23
+ self.model= SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
24
+ self.processor = SpeechT5Processor.from_pretrained(checkpoint)
25
+ self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)
26
+ embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
27
+ self.embeddings_dataset = embeddings_dataset
28
+
29
  self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
30
 
31