Dupaja committed on
Commit
9f940c9
·
1 Parent(s): 8869945

Update handler.py to load SpeechT5 locally with transformers, following the SpeechT5 article

Browse files

https://huggingface.co/blog/speecht5

Files changed (1) hide show
  1. handler.py +27 -8
handler.py CHANGED
@@ -1,28 +1,47 @@
1
- from huggingface_hub import InferenceClient
 
 
 
2
  from datasets import load_dataset
3
  import soundfile as sf
4
  from typing import Dict, List, Any
5
 
 
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
- self.client = InferenceClient(repo_id="microsoft/speecht5_tts", task="text-to-speech")
 
 
 
 
 
9
  self.embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
10
 
 
11
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
12
- text = data.get("inputs", "")
13
- speaker_embedding = self.embeddings_dataset['xvector'][7306].unsqueeze(0).tolist()
 
 
 
 
 
 
 
 
 
14
 
15
- response = self.client(payload={"inputs": text, "forward_params": {"speaker_embeddings": speaker_embedding}}, options={"wait_for_model": True})
16
 
17
  # Write the response audio to a file
18
- sf.write("speech.wav", response.audio, response.sampling_rate)
19
 
20
  # Return the expected response format
21
  return {
22
  "statusCode": 200,
23
  "body": {
24
- "audio": response.audio, # Consider encoding this to a suitable format
25
- "sampling_rate": response.sampling_rate
26
  }
27
  }
28
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
5
  from datasets import load_dataset
6
  import soundfile as sf
7
  from typing import Dict, List, Any
8
 
9
+
10
+
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
+
14
+ checkpoint = "microsoft/speecht5_tts"
15
+
16
+ self.model= SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
17
+ self.processor = SpeechT5Processor.from_pretrained(checkpoint)
18
+ self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
19
  self.embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
20
 
21
+
22
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
23
+
24
+ given_text = data.get("inputs", "")
25
+
26
+
27
+ speaker_embeddings = torch.tensor(self.embeddings_dataset[7306]["xvector"]).unsqueeze(0)
28
+
29
+ inputs = self.processor(text=given_text, return_tensors="pt")
30
+
31
+ speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
32
+
33
+
34
 
 
35
 
36
  # Write the response audio to a file
37
+ sf.write("current_sample.wav", speech.numpy(), samplerate=16000)
38
 
39
  # Return the expected response format
40
  return {
41
  "statusCode": 200,
42
  "body": {
43
+ "audio": speech.numpy(), # Consider encoding this to a suitable format
44
+ "sampling_rate": 16000
45
  }
46
  }
47