arthur-stackadoc-com commited on
Commit
5593e7e
·
1 Parent(s): 206504f
Files changed (2) hide show
  1. call_handler.py +9 -4
  2. handler.py +20 -8
call_handler.py CHANGED
@@ -1,14 +1,19 @@
1
  from handler import EndpointHandler
 
2
 
3
  # init handler
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
- non_holiday_payload = {"inputs": "I am quite excited how this will turn out", }
8
- holiday_payload = {"inputs": "Today is a though day"}
 
 
9
 
10
  # test the handler
11
- non_holiday_pred=my_handler(non_holiday_payload)
 
12
 
13
  # show results
14
- print("holiday_payload", holiday_payload)
 
 
1
  from handler import EndpointHandler
2
+ import base64
3
 
4
  # init handler
5
  my_handler = EndpointHandler(path=".")
6
 
7
  # prepare sample payload
8
+ text_payload = {"text": "I am quite excited how this will turn out"}
9
+ audio_payload = {"audio": base64.b64encode(
10
+ open("/home/arthur/data/musicdb/split_demucses/474499231_456751864 --__-- Breakbot - Programme_mp3_drums/chunk_11.wav", 'rb').read())}
11
+
12
 
13
  # test the handler
14
+ # text_pred = my_handler(text_payload)
15
+ audio_pred = my_handler(audio_payload)
16
 
17
  # show results
18
+ # print("text_pred", text_pred)
19
+ print("audio_pred", audio_pred)
handler.py CHANGED
@@ -1,6 +1,11 @@
 
1
  from typing import Dict, List, Any
 
 
2
  from transformers import ClapModel, ClapProcessor
3
- import gc
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
@@ -12,13 +17,20 @@ class EndpointHandler:
12
  """
13
  data args:
14
  inputs (:obj: `str`)
15
- date (:obj: `str`)
16
  Return:
17
  A :obj:`list` | `dict`: will be serialized and returned
18
  """
19
- print(type(data))
20
- query = data['inputs']
21
- text_inputs = self.processor(text=query, return_tensors="pt")
22
- text_embed = self.model.get_text_features(**text_inputs)[0]
23
- gc.collect()
24
- return text_embed.detach().numpy()
 
 
 
 
 
 
 
 
 
1
+ # import io
2
  from typing import Dict, List, Any
3
+
4
+ # import librosa
5
  from transformers import ClapModel, ClapProcessor
6
+ # import gc
7
+ # import base64
8
+
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
 
17
  """
18
  data args:
19
  inputs (:obj: `str`)
 
20
  Return:
21
  A :obj:`list` | `dict`: will be serialized and returned
22
  """
23
+ # print(type(data))
24
+ if 'text' in data:
25
+ query = data['text']
26
+ text_inputs = self.processor(text=query, return_tensors="pt")
27
+ text_embed = self.model.get_text_features(**text_inputs)[0]
28
+ return text_embed.detach().numpy()
29
+
30
+ # if 'audio' in data:
31
+ # # Load the audio data into librosa
32
+ # audio_buffer = io.BytesIO(base64.b64decode(data['audio']))
33
+ # y, sr = librosa.load(audio_buffer, sr=48000)
34
+ # inputs = self.processor(audios=y, sampling_rate=sr, return_tensors="pt")
35
+ # embedding = self.model.get_audio_features(**inputs)[0]
36
+ # gc.collect()