import os
import gradio as gr
from typing import Callable, Generator, Optional
import base64
from openai import OpenAI

END_POINT = os.environ.get("ENDPOINT")    # base URL of the OpenAI-compatible API
SECRET_KEY = os.environ.get("SECRETKEY")  # API key for that endpoint
USERS = os.environ.get("USERS")           # optional login credentials for the app
PWD = os.environ.get("PWD")               # (note: most shells also set PWD to the working directory)

def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a chat function with the specified model."""
    
    # Instantiate an OpenAI client for a custom endpoint
    try:
        client = OpenAI(
            base_url=END_POINT,
            api_key=SECRET_KEY,  
        )
    except Exception as e:
        print(f"Failed to create the OpenAI client (check ENDPOINT and SECRETKEY): {e}")
        raise

    def predict(
        messages: list,  
        temperature: float,
        max_tokens: int,
        top_p: float
    ) -> Generator[str, None, None]:
        """Stream the model's reply, yielding the accumulated text so far."""
        try:
            # Call the OpenAI API with the formatted messages
            response = client.chat.completions.create(
                model=model_name,  
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True, 
                response_format={"type": "text"},
            )

            response_text = ""
            for chunk in response:
                # delta.content is None for role-only and final chunks, so guard
                # against None instead of calling len() on it
                content = chunk.choices[0].delta.content
                if content:
                    response_text += content
                    yield response_text.strip()
            
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict
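
# A minimal usage sketch of the chat function returned by get_fn (assumes
# ENDPOINT and SECRETKEY are set; "my-model" is a placeholder for a model
# your endpoint actually serves):
#
#   chat = get_fn("my-model")
#   for partial in chat(
#       messages=[{"role": "user", "content": "Hello!"}],
#       temperature=0.7, max_tokens=256, top_p=0.95,
#   ):
#       print(partial)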

def get_image_base64(path: str, ext: str) -> str:
    """Read a local file and return it as a base64-encoded data URI."""
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:image/{ext};base64,{encoded_string}"

def handle_user_msg(message) -> str:
    """Normalize a user message (plain text or Gradio multimodal dict) to a string."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                # note: PDFs are given an image/pdf data URI here, which most
                # Markdown renderers will not display inline
                encoded_str = get_image_base64(message["files"][-1], ext)
                return f"{message.get('text', '')}\n![Image]({encoded_str})"
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
        else:
            return message.get("text", "")
    else:
        raise NotImplementedError("Unsupported message type")

def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            # history arrives as [user, assistant] pairs; a pair whose assistant
            # half is None is a bare file upload that belongs with the next message
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x  # No additional postprocessing needed

    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess

def registry(name: Optional[str] = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio Interface with similar styling and parameters."""
    
    # Retrieving preprocess and postprocess functions
    _, _, preprocess, postprocess = get_interface_args("chat")
    
    # Getting the predict function
    predict_fn = get_fn(model_name=name, **kwargs)
    
    # Defining a wrapper function that integrates preprocessing and postprocessing
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        # Preprocessing the inputs
        preprocessed = preprocess(message, history)
        
        # Extracting the preprocessed messages and prepending the system prompt
        # (without this, the "System prompt" textbox would be silently ignored)
        messages = preprocessed["messages"]
        if system_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})
        
        # Calling the predict function to generate the response
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=int(max_tokens),  # sliders yield floats; the API expects an int
            top_p=top_p
        )
        
        # Streaming each partial response back to Gradio as it arrives
        for partial_response in response_generator:
            yield postprocess(partial_response)

    # Creating the Gradio ChatInterface with the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt"
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    
    return interface
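

if __name__ == "__main__":
    # A minimal launch sketch: "gpt-4o-mini" is a placeholder model name, so
    # replace it with one your endpoint serves; login is enabled only when
    # both USERS and PWD are set.
    demo = registry(name="gpt-4o-mini")
    demo.launch(auth=(USERS, PWD) if USERS and PWD else None)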