# llama-hqq-1-bit / app.py
import gradio as gr
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
import torch, transformers
from threading import Thread
import time
# Load the model and tokenizer
model_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'
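# from_quantized loads the 1-bit HQQ-quantized weights together with the LoRA
# adapter shipped with this checkpoint (used to recover quality lost to extreme
# quantization); a CUDA GPU is required since device='cuda'.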
model = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora', device='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Set up inference mode
tokenizer.add_bos_token = False  # BOS is added manually via "<s>" in the prompt below
tokenizer.add_eos_token = False
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.config.use_cache = True
model.eval()
# Optional: torch compile for faster inference
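# (the first request will be slower while torch.compile traces and compiles the model)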
model = torch.compile(model)
def chat_processor(chat, max_new_tokens=100, do_sample=True, device='cuda'):
    tokenizer.use_default_system_prompt = False

    # Stream decoded text chunks as they are generated; skip_prompt hides the echoed input.
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Wrap the user message in the Llama-2 chat template ([INST] ... [/INST]).
    generate_params = dict(
        tokenizer("<s> [INST] " + chat + " [/INST] ", return_tensors="pt").to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        top_p=0.90 if do_sample else None,
        top_k=50 if do_sample else None,
        temperature=0.6 if do_sample else None,
        num_beams=1,
        repetition_penalty=1.2,
    )

    # Run generation in a background thread so the streamer can be consumed immediately.
    t = Thread(target=model.generate, kwargs=generate_params)
    t.start()

    # Debugging: print the streamed output to the console.
    # print("User: ", chat)
    # print("Assistant: ")
    # outputs = ""
    # for text in streamer:
    #     outputs += text
    #     print(text, end="", flush=True)
    # torch.cuda.empty_cache()

    return t, streamer
def chat(message, history):
    # Gradio streaming callback: yield the growing response as text chunks arrive.
    t, stream = chat_processor(chat=message)
    response = ""
    for character in stream:  # each item is a decoded text chunk, not a single character
        if character is not None:
            response += character
            # print(character)
            yield response
            time.sleep(0.1)
    t.join()
    torch.cuda.empty_cache()
gr.ChatInterface(chat).launch()
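# Example (not part of the app): a minimal sketch of how the running interface could be
# queried from another process with gradio_client, assuming a recent Gradio version that
# exposes the ChatInterface under the "/chat" API route and the default local URL.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("Hello, who are you?", api_name="/chat"))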