# marco-o1/helper.py
import base64
import os
from typing import Callable, Generator, Optional

import gradio as gr
from openai import OpenAI

# Endpoint configuration, read from the environment.
END_POINT = os.environ.get("ENDPOINT")
SECRET_KEY = os.environ.get("SECRETKEY")
# USERS and PWD are read here but not used in this module; they are presumably
# login credentials consumed elsewhere in the Space.
USERS = os.environ.get("USERS")
PWD = os.environ.get("PWD")
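
# Illustrative values only (assumptions, not the Space's actual configuration);
# the endpoint is expected to speak the OpenAI-compatible chat-completions protocol:
#   ENDPOINT="https://my-inference-server.example.com/v1"
#   SECRETKEY="sk-..."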


def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function bound to the specified model."""
    # Instantiate an OpenAI client pointed at the custom endpoint.
    try:
        client = OpenAI(
            base_url=END_POINT,
            api_key=SECRET_KEY,
        )
    except Exception as e:
        print(f"The API key or base URL was not defined: {e}")
        raise

    def predict(
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Call the OpenAI-compatible API with the formatted messages.
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                response_format={"type": "text"},
            )
            response_text = ""
            for chunk in response:
                if not chunk.choices:
                    continue
                # Delta content can be None (e.g. on the final chunk), so check
                # truthiness instead of calling len() on it directly.
                content = chunk.choices[0].delta.content
                if content:
                    response_text += content
                    yield response_text.strip()
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {e}")
            yield f"An error occurred: {e}"

    return predict
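
# Sketch of direct use (hypothetical model name; assumes the endpoint serves
# an OpenAI-compatible chat-completions API):
#   chat = get_fn("marco-o1")
#   for partial in chat(messages=[{"role": "user", "content": "Hi"}],
#                       temperature=0.7, max_tokens=256, top_p=0.95):
#       print(partial)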


def get_image_base64(path: str, ext: str) -> str:
    """Read a local file and return it as a base64 data URI."""
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/{ext};base64,{encoded_string}"


def handle_user_msg(message) -> str:
    """Flatten a Gradio chat message (plain string or multimodal dict) into text."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                # Note: PDFs are also embedded with an image/<ext> MIME type here;
                # whether the model accepts that depends on the endpoint.
                encoded_str = get_image_base64(message["files"][-1], ext)
                return f"{message.get('text', '')}\n![Image]({encoded_str})"
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
        else:
            return message.get("text", "")
    else:
        raise NotImplementedError("Unsupported message type")


def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    # A history entry without an assistant reply holds the files
                    # attached to the pending user turn.
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        def postprocess(x):
            # No additional postprocessing needed.
            return x
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess
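
# Illustrative preprocess output for a one-turn, tuple-style history:
#   preprocess("How big is it?", [["What is this?", "A cat."]])
#   -> {"messages": [{"role": "user", "content": "What is this?"},
#                    {"role": "assistant", "content": "A cat."},
#                    {"role": "user", "content": "How big is it?"}]}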


def registry(name: Optional[str] = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface wired to the custom endpoint."""
    # Retrieve the preprocess and postprocess functions.
    _, _, preprocess, postprocess = get_interface_args("chat")
    # Get the predict function bound to the model.
    predict_fn = get_fn(model_name=name, **kwargs)

    # Wrapper that integrates preprocessing, generation, and postprocessing.
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        # Preprocess the inputs into OpenAI-style messages.
        preprocessed = preprocess(message, history)
        messages = preprocessed["messages"]
        # Prepend the system prompt so the textbox value actually takes effect.
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + messages
        # Call the predict function to generate the streamed response.
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        # Yield each partial response; Gradio handles the streaming display.
        for partial_response in response_generator:
            yield postprocess(partial_response)

    # Create the Gradio ChatInterface around the wrapper function.
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    return interface
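

# Minimal launch sketch (assumption: "marco-o1" is a model name the endpoint
# actually serves; replace it with the deployed model's real name):
if __name__ == "__main__":
    registry(name="marco-o1").launch()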