import os
import base64
from typing import Callable, Generator

import gradio as gr
from openai import OpenAI

END_POINT = os.environ.get("ENDPOINT")
SECRET_KEY = os.environ.get("SECRETKEY")
USERS = os.environ.get("USERS")
PWD = os.environ.get("PWD")
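# The variables above are expected to come from the Space's secrets / environment.
# Illustrative values only (assumptions, not from the original file):
#   ENDPOINT=https://my-inference-server.example.com/v1   (OpenAI-compatible base URL)
#   SECRETKEY=sk-...
#   USERS=demo_user
#   PWD=demo_password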
def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function bound to the specified model."""
    # Instantiate an OpenAI client pointed at the custom endpoint
    try:
        client = OpenAI(
            base_url=END_POINT,
            api_key=SECRET_KEY,
        )
    except Exception as e:
        print(f"Could not create the OpenAI client (check ENDPOINT and SECRETKEY): {e}")
        raise
    def predict(
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Call the OpenAI-compatible API with the formatted messages
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                response_format={"type": "text"},
            )

            response_text = ""
            for chunk in response:
                if not chunk.choices:
                    continue
                # delta.content can be None (e.g. on the final chunk), so guard before appending
                content = chunk.choices[0].delta.content
                if content:
                    response_text += content
                    yield response_text.strip()

            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {e}")
            yield f"An error occurred: {e}"

    return predict
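# A minimal sketch (not in the original) of calling the chat function directly,
# assuming the endpoint serves an OpenAI-compatible model named "my-model":
#
#   chat_fn = get_fn("my-model")
#   history = [{"role": "user", "content": "Say hello in one sentence."}]
#   for partial in chat_fn(messages=history, temperature=0.7, max_tokens=128, top_p=0.95):
#       print(partial)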
def get_image_base64(url: str, ext: str) -> str:
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/{ext};base64,{encoded_string}"
def handle_user_msg(message: str | dict) -> str:
    """Normalize a Gradio chat message (plain string or multimodal dict) into prompt text."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                # Append the attachment as a base64 data URI so it is not silently dropped
                # (the original computed encoded_str but never used it)
                encoded_str = get_image_base64(message["files"][-1], ext)
                return f"{message.get('text', '')}\n{encoded_str}"
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
        else:
            return message.get("text", "")
    else:
        raise NotImplementedError("Unsupported message type")
def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            # history is a list of (user, assistant) tuples; a None assistant turn
            # means the "user" entry actually holds a pending file upload
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x  # No additional postprocessing needed
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess
def registry(name: str = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom OpenAI-compatible endpoint."""
    # Retrieve the preprocess and postprocess functions for the chat pipeline
    _, _, preprocess, postprocess = get_interface_args("chat")
    # Build the streaming predict function for the requested model
    predict_fn = get_fn(model_name=name, **kwargs)

    # Wrapper that glues preprocessing, generation, and streaming together
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        # Preprocess the incoming message and history into OpenAI-style messages
        preprocessed = preprocess(message, history)
        messages = preprocessed["messages"]
        # Prepend the system prompt so the "System prompt" textbox actually takes effect
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + messages
        # Call the predict function and stream the response
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        response = ""
        for partial_response in response_generator:
            response = partial_response  # Gradio handles the streaming display
            yield response

    # Create the Gradio ChatInterface with the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    return interface
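# --- Usage sketch (assumption, not part of the original file) ---
# A minimal way to launch the app locally, assuming ENDPOINT and SECRETKEY are set
# and that USERS/PWD (read above but otherwise unused) are meant to be the Gradio
# basic-auth credentials; the model name below is a placeholder.
if __name__ == "__main__":
    demo = registry(name="my-model")
    auth = (USERS, PWD) if USERS and PWD else None
    demo.queue().launch(auth=auth)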