import os
import subprocess
import base64
from threading import Thread
from typing import Callable

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, list_repo_files
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
def get_fn(model_path: str, **model_kwargs):
    """Create a chat function with the specified model."""
    # Initialize tokenizer and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Simple flash-attention installation attempt
    try:
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            # Preserve the existing environment so pip stays on PATH
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True,
            check=True,
        )
        # Try loading the model with Flash Attention 2
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            quantization_config=quantization_config,
            attn_implementation="flash_attention_2",
        )
    except Exception as e:
        print(f"Flash Attention failed, falling back to default attention: {str(e)}")
        # Fall back to the default attention implementation
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=quantization_config,
        )
    def predict(
        message: str,
        history,
        system_prompt: str,
        temperature: float,
        max_new_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float,
    ):
        try:
            # Format the conversation in ChatML format
            instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
            for user_msg, assistant_msg in history:
                instruction += f"<|im_start|>user\n{user_msg}\n<|im_end|>\n<|im_start|>assistant\n{assistant_msg}\n<|im_end|>\n"
            instruction += f"<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n"

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
            input_ids, attention_mask = enc.input_ids, enc.attention_mask

            # Truncate to the context window if needed
            if input_ids.shape[1] > 8192:  # Using n_ctx from the original configuration
                input_ids = input_ids[:, -8192:]
                attention_mask = attention_mask[:, -8192:]

            generate_kwargs = dict(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                streamer=streamer,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                top_p=top_p,
            )

            # Run generation in a background thread and stream tokens as they arrive
            t = Thread(target=model.generate, kwargs=generate_kwargs)
            t.start()

            response_text = ""
            for new_token in streamer:
                if new_token in ["<|endoftext|>", "<|im_end|>"]:
                    break
                response_text += new_token
                yield response_text.strip()

            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict
def get_image_base64(url: str, ext: str):
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:image/" + ext + ";base64," + encoded_string
def handle_user_msg(message: str):
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message["files"] is not None and len(message["files"]) > 0:
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Not supported file type {ext}")
            content = [
                {"type": "text", "text": message["text"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message["text"]
        return content
    else:
        raise NotImplementedError
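# Illustrative only (hypothetical input, not shipped with this Space): given a multimodal
# message such as {"text": "Describe this image.", "files": ["photo.png"]}, handle_user_msg
# returns a content list pairing the text with a base64 "image_url" entry, which is the
# shape preprocess() below expects for image-bearing turns.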
def get_interface_args(pipeline):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    files = user_msg
            if isinstance(message, str) and files is not None:
                message = {"text": message, "files": files}
            elif isinstance(message, dict) and files is not None:
                if message["files"] is None or len(message["files"]) == 0:
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x
    else:
        # Add other pipeline types when they are needed
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess
def get_pipeline(model_name):
    # Determine the pipeline type based on the model name
    # For simplicity, assume all models are chat models for now
    return "chat"
def get_model_path(name: str = None, model_path: str = None) -> str:
    """Get the local path to the model."""
    if model_path:
        return model_path
    if name:
        if "/" in name:
            return name  # Return the HF model ID directly
        else:
            # You could maintain a mapping of friendly names to HF model IDs
            model_mapping = {
                # Add any default model mappings here
                "example-model": "organization/model-name"
            }
            if name not in model_mapping:
                raise ValueError(f"Unknown model name: {name}")
            return model_mapping[name]
    raise ValueError("Either name or model_path must be provided")
def registry(name: str = None, model_path: str = None, **kwargs):
    """Create a Gradio ChatInterface with similar styling and parameters."""
    model_path = get_model_path(name, model_path)
    fn = get_fn(model_path, **kwargs)

    interface = gr.ChatInterface(
        fn=fn,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                "You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
            gr.Slider(1, 80, 40, label="Top K sampling"),
            gr.Slider(0, 2, 1.1, label="Repetition penalty"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
        ],
    )
    return interface
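

# Minimal launch sketch (assumptions: this module serves as the Space's entry point, and
# "Qwen/Qwen2.5-0.5B-Instruct" is only a placeholder ChatML-style chat model, not one the
# original author specified; swap in the model you actually want to serve).
if __name__ == "__main__":
    demo = registry(model_path="Qwen/Qwen2.5-0.5B-Instruct")
    demo.launch()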