Dupaja committed on
Commit
d65f95a
·
1 Parent(s): 1ddbd3e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -23
handler.py CHANGED
@@ -1,37 +1,35 @@
1
- from typing import Dict
2
- from transformers import pipeline
3
  import torch
 
 
4
  import soundfile as sf
5
- import io
6
 
7
class EndpointHandler:
    """Pre-commit revision: synthesizes speech locally via a transformers pipeline.

    NOTE(review): this revision calls ``load_dataset()`` but no import of it is
    visible in the file's import block — ``__init__`` would raise NameError.
    The replacing revision in this commit adds ``from datasets import
    load_dataset``.
    """

    def __init__(self, path=""):
        # Local text-to-speech pipeline (model weights fetched on first use).
        self.synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
        # Speaker x-vector dataset; load_dataset is NOT imported in this revision.
        self.embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        """Synthesize speech for ``data["inputs"]`` and return it as WAV bytes.

        NOTE(review): the annotated return type ``Dict[str, str]`` does not
        match the actual payload (``body`` holds raw bytes, ``headers`` a
        dict) — left as written in the original.
        """
        text = data.get("inputs", "")
        # Speaker embedding: x-vector at dataset row 7306, given a leading
        # batch dimension via unsqueeze(0).
        speaker_embedding = torch.tensor(self.embeddings_dataset[7306]["xvector"]).unsqueeze(0)

        # Generate speech using the synthesiser
        speech = self.synthesiser(text, forward_params={"speaker_embeddings": speaker_embedding})

        # Convert numpy audio array to a WAV byte stream.
        audio_buffer = io.BytesIO()
        sf.write(file=audio_buffer, data=speech["audio"], samplerate=speech["sampling_rate"], format='WAV')
        audio_buffer.seek(0)
        audio_wav = audio_buffer.read()

        # Prepare the response headers.
        headers = {
            "Content-Type": "audio/wav"
        }

        # Create the response as raw audio bytes.
        response = {
            "statusCode": 200,
            "body": audio_wav,
            "headers": headers
        }

        return response
 
 
 
1
  import torch
2
+ from transformers import pipeline
3
+ from datasets import load_dataset
4
  import soundfile as sf
5
+ from huggingface_hub.inference_api import InferenceApi
6
 
7
class EndpointHandler:
    """Post-commit revision: delegates synthesis to the hosted Inference API
    instead of running a local transformers pipeline."""

    def __init__(self):
        # Remote TTS endpoint; no local model weights are loaded.
        self.api = InferenceApi(repo_id="microsoft/speecht5_tts", task="text-to-speech")
        # Speaker x-vector dataset — downloaded at construction time (network).
        self.embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

    def __call__(self, data):
        """Handle one inference request.

        Only ``data["inputs"]`` (the text to speak) is read; it defaults to
        the empty string when absent.

        NOTE(review): the code assumes the InferenceApi response is a mapping
        with "audio" and "sampling_rate" keys — confirm against the
        text-to-speech task's actual output format; some tasks return raw
        audio bytes instead.
        """
        text = data.get("inputs", "")
        # Extract speaker_embedding using the index from your dataset, or replace with your own logic.
        speaker_embedding = torch.tensor(self.embeddings_dataset[7306]["xvector"]).unsqueeze(0)
        # Convert embedding to a plain nested list so it survives JSON serialization.
        speaker_embedding_list = speaker_embedding.tolist()

        # Use the API to run the model; wait_for_model blocks until the hosted
        # model is loaded rather than failing fast.
        response = self.api(inputs=text, parameters={"forward_params": {"speaker_embeddings": speaker_embedding_list}}, options={"wait_for_model": True})

        # Write the response audio to a file
        # Note: This might not be possible in all environments (read-only or
        # ephemeral filesystems) — ensure this is suitable for your deployment.
        sf.write("speech.wav", response["audio"], samplerate=response["sampling_rate"])

        # Return the expected response format
        return {
            "statusCode": 200,
            "body": {
                "audio": response["audio"],  # Consider encoding this to a suitable format (e.g. base64 WAV)
                "sampling_rate": response["sampling_rate"]
            }
        }
34
 
35
# Module-level singleton: constructing it runs __init__, which downloads the
# embeddings dataset — network access happens at import time.
handler = EndpointHandler()