import gradio as gr
import numpy as np

from transformers import pipeline
from custom_chat_interface import CustomChatInterface

from llama_cpp import Llama

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""


class MyModel:
    def __init__(self):
        self.client = None
        self.current_model = ""

    def respond(
        self,
        message,
        history: list[tuple[str, str]],
        model,
        system_message,
        max_tokens,
        temperature,
        top_p,
    ):
        # (Re)load only when the selection changes or no model is loaded yet.
        if self.client is None or model != self.current_model:
            model_id, quant = model.split(",")
            self.client = Llama.from_pretrained(
                repo_id=model_id.strip(),
                filename=f"*{quant.strip()}*.gguf",
                n_ctx=2048,  # context window; increase if conversations run long
            )
            self.current_model = model

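        # Rebuild the full conversation for the chat template:
        # system prompt first, then alternating user/assistant turns.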
        messages = [{"role": "system", "content": system_message}]

        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})

        messages.append({"role": "user", "content": message})

        response = ""
        for message in self.client.create_chat_completion(
            messages,
            temperature=temperature,
            top_p=top_p,
            stream=True,
            max_tokens=max_tokens,
        ):
            delta = message["choices"][0]["delta"]
            if "content" in delta:
                response += delta["content"]
                yield response


transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")


def transcribe(audio):
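    # `audio` is assumed to be the (sample_rate, samples) tuple that
    # gr.Audio(type="numpy") produces and CustomChatInterface forwards here.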
    sr, y = audio

    # Convert to mono if stereo
    if y.ndim > 1:
        y = y.mean(axis=1)

    # Normalize to float32 in [-1, 1]; skip silent clips to avoid dividing by zero
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    text = transcriber({"sampling_rate": sr, "raw": y})["text"]
    return text


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
my_model = MyModel()
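# Each entry is "repo_id, quantization tag"; respond() splits on the comma
# and globs for a matching *.gguf file in the repo.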
model_choices = ["lab2-as/lora_model_gguf, Q4", "lab2-as/lora_model_no_quant_gguf, Q4"]
demo = CustomChatInterface(
    my_model.respond,
    transcriber=transcribe,
    additional_inputs=[
        gr.Dropdown(
            choices=model_choices,
            value=model_choices[0],
            label="Select Model",
        ),
        gr.Textbox(
            value="You are a friendly Chatbot.",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=128,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (Nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()