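"""Gradio chat demo that streams responses from the 1-bit HQQ-quantized
Llama-2-7b-chat model (mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq)."""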

import gradio as gr
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
import torch, transformers
from threading import Thread

# Load the 1-bit quantized model and its LoRA adapter
model_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'
model = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set up inference mode
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
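# BOS/EOS insertion is disabled because the prompt template below adds the
# '<s>' and '[INST]' markers by hand.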
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.config.use_cache = True
model.eval()

# Optional: torch.compile for faster inference
model = torch.compile(model)
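# Note: torch.compile compiles lazily on the first forward pass, so the first
# request will be slow, and varying sequence lengths can trigger recompiles.
# Comment the line above out if that overhead is a problem.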

def chat_processor(chat, max_new_tokens=100, do_sample=True, device='cpu'):
    # 'device' must match where the model weights live (e.g., 'cuda' if the
    # quantized model was loaded on a GPU).
    tokenizer.use_default_system_prompt = False
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Merge the tokenized prompt (input_ids / attention_mask) with the
    # generation settings into a single kwargs dict for model.generate.
    generate_params = dict(
        tokenizer("<s> [INST] " + chat + " [/INST] ", return_tensors="pt").to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        top_p=0.90 if do_sample else None,
        top_k=50 if do_sample else None,
        temperature=0.6 if do_sample else None,
        num_beams=1,
        repetition_penalty=1.2,
    )

    # Run generation on a background thread so the caller can consume the
    # streamer as tokens arrive.
    t = Thread(target=model.generate, kwargs=generate_params)
    t.start()

    # Debug variant: consume the stream here and echo it to stdout instead.
    # print("User: ", chat)
    # print("Assistant: ")
    # outputs = ""
    # for text in streamer:
    #     outputs += text
    #     print(text, end="", flush=True)
    # torch.cuda.empty_cache()

    return streamer
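
# Usage sketch (standalone, outside Gradio; the prompt text is illustrative):
#   for chunk in chat_processor("What is 1-bit quantization?", max_new_tokens=64):
#       print(chunk, end="", flush=True)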

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the new message to the history and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history):
        print("Question: ", history[-1][0])
        stream = chat_processor(chat=history[-1][0])
        history[-1][1] = ""
        # The streamer yields decoded text chunks, not single characters.
        for chunk in stream:
            # print(chunk)  # per-chunk debug output
            history[-1][1] += chunk
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)
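
# Queuing is required for the generator-based bot() to stream partial
# responses back to the Chatbot component.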
demo.queue()
demo.launch()