import os
import base64
from typing import Callable, Generator

import gradio as gr
from openai import OpenAI

# Configuration read from the environment
END_POINT = os.environ.get("ENDPOINT")
SECRET_KEY = os.environ.get("SECRETKEY")
USERS = os.environ.get("USERS")
PWD = os.environ.get("PWD")


def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a streaming chat function bound to the specified model."""
    # Instantiate an OpenAI client pointed at the custom endpoint
    try:
        client = OpenAI(
            base_url=END_POINT,
            api_key=SECRET_KEY,
        )
    except Exception as e:
        print(f"Failed to initialize the OpenAI client (check ENDPOINT and SECRETKEY): {str(e)}")
        raise

    def predict(
        messages: list,
        temperature: float,
        max_tokens: int,
        top_p: float,
    ) -> Generator[str, None, None]:
        try:
            # Call the chat completions API with the formatted messages, streaming the reply
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                stream=True,
                response_format={"type": "text"},
            )
            response_text = ""
            for chunk in response:
                # Delta content can be None (e.g. on the final chunk), so guard before appending
                content = chunk.choices[0].delta.content
                if content:
                    response_text += content
                    yield response_text.strip()
            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict


def get_image_base64(url: str, ext: str) -> str:
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/{ext};base64,{encoded_string}"


def handle_user_msg(message: str) -> str:
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".").lower()
            if ext in ["png", "jpg", "jpeg", "gif", "pdf"]:
                # Append the base64 data URI of the most recent file to the text
                encoded_str = get_image_base64(message["files"][-1], ext)
                return f"{message.get('text', '')}\n{encoded_str}"
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
        else:
            return message.get("text", "")
    else:
        raise NotImplementedError("Unsupported message type")


def get_interface_args(pipeline: str):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    # A user turn without an assistant reply holds the pending file uploads
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if not message.get("files"):
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x  # No additional postprocessing needed
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess
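

# Example (illustrative, not part of the original file): preprocess flattens a
# tuple-style Gradio history plus the new message into OpenAI-style chat messages.
# >>> _, _, preprocess, _ = get_interface_args("chat")
# >>> preprocess("How are you?", [("Hi", "Hello!")])
# {'messages': [{'role': 'user', 'content': 'Hi'},
#               {'role': 'assistant', 'content': 'Hello!'},
#               {'role': 'user', 'content': 'How are you?'}]}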


def registry(name: str = None, **kwargs) -> gr.ChatInterface:
    """Create a Gradio ChatInterface with similar styling and parameters."""
    # Retrieve the preprocess and postprocess functions for the chat pipeline
    _, _, preprocess, postprocess = get_interface_args("chat")
    # Get the predict function bound to the requested model
    predict_fn = get_fn(model_name=name, **kwargs)

    # Wrapper that ties preprocessing, generation, and streaming together
    def wrapper(message, history, system_prompt, temperature, max_tokens, top_p):
        # Preprocess the inputs into OpenAI-style messages
        preprocessed = preprocess(message, history)
        # Prepend the system prompt so it is actually passed to the model
        messages = [{"role": "system", "content": system_prompt}] + preprocessed["messages"]
        # Call the predict function to generate the response
        response_generator = predict_fn(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        # Stream partial responses back to Gradio as they arrive
        for partial_response in response_generator:
            yield postprocess(partial_response)

    # Create the Gradio ChatInterface with the wrapper function
    interface = gr.ChatInterface(
        fn=wrapper,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0.0, 1.0, value=0.7, label="Temperature"),
            gr.Slider(128, 4096, value=1024, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top P sampling"),
        ],
    )
    return interface
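

# Hedged usage sketch (not in the original file): this module exposes registry() for
# other code to call, but if it is also the Space entry point it could be launched as
# below. MODEL_NAME is a placeholder environment variable, and wiring USERS/PWD into
# Gradio's auth argument is only an assumption about why those variables are read above.
if __name__ == "__main__":
    demo = registry(name=os.environ.get("MODEL_NAME", "my-model"))
    # Enable basic auth only when both credentials are provided
    auth = (USERS, PWD) if USERS and PWD else None
    demo.launch(auth=auth)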