File size: 4,466 Bytes
c5cd944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06cfca8
 
 
 
 
 
 
 
c5cd944
 
d5b56e4
 
c5cd944
 
 
d5b56e4
 
 
c5cd944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f807815
 
 
c5cd944
 
 
 
 
 
f807815
c5cd944
f807815
 
 
c5cd944
 
 
 
 
 
 
 
 
2eebed8
c5cd944
 
 
 
2eebed8
 
 
 
c5cd944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f807815
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import logging
from typing import Any, Dict, List, Optional

import transformers

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor


class UltravoxPipeline(transformers.Pipeline):
    """Pipeline that feeds a chat prompt plus an audio clip to an Ultravox model.

    ``preprocess`` renders the conversation with the tokenizer's chat template
    and runs it, together with the audio, through an ``UltravoxProcessor``;
    ``_forward`` calls ``model.generate``; ``postprocess`` decodes the newly
    generated tokens to a string.
    """

    # Call-time keyword arguments that are forwarded to ``_forward``.
    _GENERATION_KEYS = ("temperature", "max_new_tokens", "repetition_penalty")

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline, loading a tokenizer/audio processor if omitted.

        Args:
            model: The Ultravox model; its ``config`` supplies the ids used to
                load any component that is not passed in, and ``stack_factor``.
            tokenizer: Tokenizer to use. When ``None`` it is loaded from
                ``model.config._name_or_path``, falling back to the text
                model's id/path if that fails.
            audio_processor: Audio processor to use. When ``None`` it is
                loaded from the audio model's id/path in ``model.config``.
            **kwargs: Passed through to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config._name_or_path
                )
            except Exception:
                # Best-effort fallback to the text model's tokenizer. Catching
                # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
                # and SystemExit propagate instead of being swallowed here.
                logging.debug(
                    "Could not load tokenizer from %r; falling back to the "
                    "text model's tokenizer.",
                    model.config._name_or_path,
                    exc_info=True,
                )
                tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model.config.text_model_id or model.config.text_config._name_or_path
                )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty``
        are recognised, and all of them go to ``_forward``; the preprocess and
        postprocess dicts are always empty.
        """
        generation_kwargs = {
            key: kwargs[key] for key in self._GENERATION_KEYS if key in kwargs
        }
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a raw input dict into processor output ready for ``_forward``.

        Args:
            inputs: Dict with a required ``"audio"`` entry and optional
                ``"turns"`` (list of ``{"role", "content"}`` messages),
                ``"prompt"`` (defaults to ``"<|audio|>"``) and
                ``"sampling_rate"`` (defaults to 16000 with a warning).

        Returns:
            The ``UltravoxProcessor`` output; if it contains
            ``"audio_values"``, that entry is cast to the model's dtype.

        Raises:
            ValueError: If ``inputs`` has no ``"audio"`` key.
        """
        # Work on a copy: appending the user turn below must not mutate the
        # caller's conversation list. ``or []`` also tolerates ``turns=None``.
        turns = list(inputs.get("turns") or [])

        # Make sure the conversation ends with a user turn carrying the
        # audio placeholder.
        if not turns or turns[-1]["role"] != "user":
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        # TODO: allow text-only mode?
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would surface as an opaque KeyError below.
        if "audio" not in inputs:
            raise ValueError("Audio input is required")

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Generate a continuation and return only the newly produced tokens.

        Args:
            model_inputs: Output of ``preprocess``; must contain
                ``"input_ids"`` and is unpacked into ``model.generate``.
            temperature: Sampling temperature. ``None`` or any falsy value
                (e.g. ``0``) disables sampling.
            max_new_tokens: Forwarded to ``model.generate``.
            repetition_penalty: Forwarded to ``model.generate``.

        Returns:
            The first generated sequence with the prompt tokens sliced off.
        """
        # A temperature of 0 is normalised to None, i.e. no sampling.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the tokenizer's EOS token and, when the tokenizer defines
        # it as an added token, on "<|eot_id|>" as well.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # Drop the prompt so that only the generated continuation is decoded.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode the generated token ids to text, skipping special tokens."""
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)


# Import-time side effect: register UltravoxPipeline with the transformers
# pipeline registry under the task name "ultravox-pipeline", with
# transformers.AutoModel as the PyTorch model class and type "multimodal".
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)