from typing import Any, Dict, List, Optional

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Upper bound, in seconds, on the audio length a single request may ask for.
MAX_DURATION_SECONDS = 30.0
# Duration used when the payload does not specify one.
DEFAULT_DURATION_SECONDS = 5.0


class EndpointHandler:
    """Inference Endpoints handler that generates audio from a text prompt with MusicGen."""

    def __init__(self, path: str = "", device: Optional[str] = None):
        """Load the MusicGen processor and model.

        Args:
            path: Directory or hub id that both the processor and model are loaded from.
            device: Torch device to run on. Defaults to "cuda" when available,
                otherwise "cpu", so the handler does not crash on CPU-only hosts.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the text prompt in ``data``.

        Args:
            data: Request payload. This dict is mutated: the keys below are popped.
                - "inputs": the text prompt. If absent, the payload dict itself is
                  used as the prompt (kept for compatibility; presumably callers
                  always send "inputs" — verify).
                - "parameters" (optional): extra keyword arguments forwarded to
                  ``model.generate``. A "max_new_tokens" entry here overrides the
                  value derived from "duration".
                - "duration" (optional, default 5.0): requested length in seconds.
                  It is coerced to float and clamped to [0, MAX_DURATION_SECONDS].

        Returns:
            A one-element list ``[{"generated_text": array}]``. The array is the
            first generated sample moved to CPU as a numpy array. Presumably it
            holds audio values shaped (channels, samples) — confirm. The key name
            "generated_text" is kept for client compatibility even though the
            payload is audio.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        # JSON clients may send the duration as a string; coerce before clamping.
        duration = float(data.pop("duration", DEFAULT_DURATION_SECONDS))
        duration = max(min(duration, MAX_DURATION_SECONDS), 0.0)

        # Convert seconds into generation steps using the audio encoder's frame rate.
        # generate() rejects max_new_tokens == 0, so always ask for at least one step.
        frame_rate = self.model.config.audio_encoder.frame_rate
        max_new_tokens = max(int(duration * frame_rate), 1)

        model_inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Merge rather than passing max_new_tokens twice. Passing it twice raised
        # "got multiple values for keyword argument" whenever the client also
        # supplied it. Explicit client parameters win.
        generate_kwargs = {"max_new_tokens": max_new_tokens, **parameters}
        outputs = self.model.generate(**model_inputs, **generate_kwargs)

        prediction = outputs[0].cpu().numpy()
        return [{"generated_text": prediction}]