import os
import subprocess
from threading import Thread
import random

import torch
import spaces
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# Install FlashAttention at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids
# compiling the CUDA kernels from source. The current environment is preserved
# so that pip stays on PATH.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
MODEL_ID = "speakleash/Bielik-7B-Instruct-v0.1"
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 1024

# UI settings are read from Space environment variables; fall back to defaults
# so the app also starts when they are not set.
COLOR = os.environ.get("COLOR", "blue")
EMOJI = os.environ.get("EMOJI", "")
DESCRIPTION = os.environ.get("DESCRIPTION", f"Chat demo for {MODEL_NAME}")
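# Example configuration (hypothetical values; on Spaces these are set as
# repository variables rather than in code):
#   COLOR="amber"                                   -> primary hue for gr.themes.Soft
#   EMOJI="🦅"                                       -> prepended to the chat title
#   DESCRIPTION="Online demo of Bielik-7B-Instruct"  -> shown under the title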
# Load the tokenizer and the model: 4-bit bitsandbytes quantization with
# bfloat16 compute, FlashAttention 2 for the attention kernels.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)
def generate(
    instruction,
    stop_tokens,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    # Stream tokens as they are produced instead of waiting for the full completion
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    # Keep only the most recent CONTEXT_LENGTH prompt tokens
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    # Run generation in a background thread and consume the streamer here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)
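# generate() is a generator: it yields the accumulated completion after every
# streamed token so the UI can render partial output. A minimal, hypothetical
# call (the prompt string is only an illustration) would look like:
#
#   prompt = "<|im_start|>user\nKim jesteś?\n<|im_end|>\n<|im_start|>assistant\n"
#   for partial in generate(prompt, ["<|im_end|>"], 0.6, 128, 40, 1.1, 0.95):
#       print(partial)  # grows by one token per iteration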
def predict(
    message,
    history,
    system_prompt,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    repetition_penalty = float(repetition_penalty)
    print(
        "predict() inputs:",
        [
            message,
            history,
            system_prompt,
            temperature,
            max_new_tokens,
            top_k,
            repetition_penalty,
            top_p,
        ],
    )
    # Format the conversation history with the selected chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
        for human, assistant in history:
            instruction += (
                "<|im_start|>user\n"
                + human
                + "\n<|im_end|>\n<|im_start|>assistant\n"
                + assistant
                + "\n<|im_end|>"
            )
        instruction += (
            "\n<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
        )
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = "<s>[INST] " + system_prompt
        for human, assistant in history:
            instruction += human + " [/INST] " + assistant + "</s>[INST]"
        instruction += " " + message + " [/INST]"
    elif CHAT_TEMPLATE == "Bielik":
        stop_tokens = ["</s>"]
        prompt_builder = ["<s>[INST] "]
        if system_prompt:
            prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
        for human, assistant in history:
            prompt_builder.append(f"{human} [/INST] {assistant}</s>[INST] ")
        prompt_builder.append(f"{message} [/INST]")
        instruction = "".join(prompt_builder)
    else:
        raise Exception(
            "Incorrect chat template, select 'ChatML', 'Mistral Instruct' or 'Bielik'"
        )
    print(instruction)
    for output_text in generate(
        instruction,
        stop_tokens,
        temperature,
        max_new_tokens,
        top_k,
        repetition_penalty,
        top_p,
    ):
        yield output_text
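# With CHAT_TEMPLATE = "ChatML", one prior exchange in `history` and a new
# message "Ile to jest 2+2?", predict() builds a prompt of the form
# ("Jestem Bielik." is an invented assistant reply, shown only to illustrate
# the layout):
#
#   <|im_start|>system
#   {system_prompt}
#   <|im_end|>
#   <|im_start|>user
#   Kim jesteś?
#   <|im_end|>
#   <|im_start|>assistant
#   Jestem Bielik.
#   <|im_end|>
#   <|im_start|>user
#   Ile to jest 2+2?
#   <|im_end|>
#   <|im_start|>assistant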
# Create the Gradio interface
def update_examples():
    # Reshuffle the example prompts on every page load
    exs = [["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]]
    random.shuffle(exs)
    return gr.Dataset(samples=exs)


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
    chat = gr.ChatInterface(
        predict,
        chatbot=chatbot,
        title=EMOJI + " " + MODEL_NAME + " - online chat demo",
        description=DESCRIPTION,
        examples=[["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]],
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Parameters", open=False, render=False
        ),
        additional_inputs=[
            gr.Textbox("", label="System prompt", render=False),
            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
            gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
        ],
        theme=gr.themes.Soft(primary_hue=COLOR),
    )
    demo.load(update_examples, None, chat.examples_handler.dataset)

demo.queue(max_size=20).launch()
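# To run outside of Spaces (assuming this file is saved as app.py and a CUDA
# GPU plus the imported packages are available): `python app.py`, then open
# the local URL Gradio prints (http://127.0.0.1:7860 by default).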