File size: 4,466 Bytes
c5cd944 06cfca8 c5cd944 d5b56e4 c5cd944 d5b56e4 c5cd944 f807815 c5cd944 f807815 c5cd944 f807815 c5cd944 2eebed8 c5cd944 2eebed8 c5cd944 f807815 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import logging
from typing import Any, Dict, List, Optional
import transformers
# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor
class UltravoxPipeline(transformers.Pipeline):
    """A multimodal (audio + text) generation pipeline for `UltravoxModel`.

    Wraps tokenizer/audio-processor resolution, chat templating, audio
    preprocessing, generation, and decoding behind the standard
    `transformers.Pipeline` preprocess/_forward/postprocess protocol.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, resolving tokenizer/audio processor from the
        model config when they are not supplied explicitly.

        Args:
            model: The Ultravox model to run.
            tokenizer: Text tokenizer; auto-loaded from the model checkpoint
                (or its text backbone) when None.
            audio_processor: Audio feature extractor; auto-loaded from the
                audio backbone when None.
            **kwargs: Forwarded to `transformers.Pipeline.__init__`.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # propagate. Fall back to the text backbone's tokenizer when the
                # combined checkpoint does not ship one of its own.
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        Only the recognized generation knobs are forwarded to `_forward`;
        everything else is dropped.
        """
        generation_keys = ("temperature", "max_new_tokens", "repetition_penalty")
        generation_kwargs = {k: kwargs[k] for k in generation_keys if k in kwargs}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a `{"audio": ..., "turns"/"prompt": ..., "sampling_rate": ...}`
        dict into model-ready tensors via the UltravoxProcessor.

        Raises:
            ValueError: If `inputs` has no "audio" entry.
        """
        # Copy so appending the synthesized user turn below never mutates the
        # caller's list (the original aliased inputs["turns"] directly).
        turns: List[Dict[str, str]] = list(inputs.get("turns", []))

        # Ensure the conversation ends with a user turn that carries the
        # <|audio|> placeholder the processor will expand.
        if not turns or turns[-1]["role"] != "user":
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        # TODO: allow text-only mode?
        if "audio" not in inputs:
            # Was an `assert`; raise so validation survives `python -O`.
            raise ValueError("Audio input is required")

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        # Match the model's dtype (e.g. fp16/bf16) to avoid a dtype mismatch
        # inside the audio tower.
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids.

        A temperature of 0 (or None) disables sampling (greedy decoding).
        """
        # `or None` maps temperature=0 to None so do_sample becomes False.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS, and also on <|eot_id|> for Llama-3-style chat tokenizers.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        # Strip the prompt tokens; return only the completion.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids into text, dropping special tokens."""
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)
# Make this class discoverable via transformers.pipeline("ultravox-pipeline", ...).
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pt_model=transformers.AutoModel,
    pipeline_class=UltravoxPipeline,
    type="multimodal",
)