import os
import subprocess
from threading import Thread
import random
import torch
import spaces
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
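
# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# compiling the CUDA kernels during the pip install.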
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
MODEL_ID = "speakleash/Bielik-7B-Instruct-v0.1"
CHAT_TEMPLATE = "ChatML"
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 1024
# Fall back to safe defaults so the demo also runs with these env vars unset
COLOR = os.environ.get("COLOR", "blue")
EMOJI = os.environ.get("EMOJI", "")
DESCRIPTION = os.environ.get("DESCRIPTION", f"Chat demo for {MODEL_NAME}")
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,  # load weights in 4-bit to fit the GPU
    device_map="auto",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
)
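

# Streaming generation: model.generate runs on a background thread and
# TextIteratorStreamer yields decoded text as it is produced.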
@spaces.GPU()
def generate(
instruction,
stop_tokens,
temperature,
max_new_tokens,
top_k,
repetition_penalty,
top_p,
):
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    if input_ids.shape[1] > CONTEXT_LENGTH:
        # Keep only the most recent CONTEXT_LENGTH tokens and keep the mask in sync
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
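    # Generate on a worker thread so this generator can stream partial output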
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for new_token in streamer:
outputs.append(new_token)
if new_token in stop_tokens:
break
yield "".join(outputs)
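

# predict builds a prompt in the configured chat template from the chat
# history and streams the model's reply back to the interface.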
def predict(
message,
history,
system_prompt,
temperature,
max_new_tokens,
top_k,
repetition_penalty,
top_p,
):
    repetition_penalty = float(repetition_penalty)
    print(
        "Request:",
        [
            message,
            history,
            system_prompt,
            temperature,
            max_new_tokens,
            top_k,
            repetition_penalty,
            top_p,
        ],
    )
# Format history with a given chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = "<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"
        for human, assistant in history:
            instruction += (
                "<|im_start|>user\n"
                + human
                + "\n<|im_end|>\n<|im_start|>assistant\n"
                + assistant
                + "\n<|im_end|>\n"  # close each assistant turn
            )
        instruction += (
            "<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n"
        )
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = "<s>[INST] "
        if system_prompt:
            instruction += system_prompt + "\n\n"
        for human, assistant in history:
            instruction += human + " [/INST] " + assistant + "</s>[INST] "
        instruction += message + " [/INST]"
elif CHAT_TEMPLATE == "Bielik":
stop_tokens = ["</s>"]
prompt_builder = ["<s>[INST] "]
if system_prompt:
prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
for human, assistant in history:
prompt_builder.append(f"{human} [/INST] {assistant}</s>[INST] ")
prompt_builder.append(f"{message} [/INST]")
instruction = "".join(prompt_builder)
    else:
        raise ValueError(
            "Incorrect chat template, select 'ChatML', 'Mistral Instruct' or 'Bielik'"
        )
print(instruction)
    yield from generate(
        instruction,
        stop_tokens,
        temperature,
        max_new_tokens,
        top_k,
        repetition_penalty,
        top_p,
    )
# Create Gradio interface
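# update_examples reshuffles the example prompts; wired to demo.load below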
def update_examples():
    exs = [["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]]
    random.shuffle(exs)
    return gr.Dataset(samples=exs)
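

# Assemble the Gradio chat UI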
with gr.Blocks() as demo:
chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
chat = gr.ChatInterface(
predict,
chatbot=chatbot,
        title=f"{EMOJI} {MODEL_NAME} - online chat demo".strip(),
description=DESCRIPTION,
examples=[["Kim jesteś?"], ["Ile to jest 9+2-1?"], ["Napisz mi coś miłego."]],
additional_inputs_accordion=gr.Accordion(
label="⚙️ Parameters", open=False, render=False
),
additional_inputs=[
gr.Textbox("", label="System prompt", render=False),
gr.Slider(0, 1, 0.6, label="Temperature", render=False),
gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
],
theme=gr.themes.Soft(primary_hue=COLOR),
)
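    # Reshuffle the visible examples each time the page is loaded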
demo.load(update_examples, None, chat.examples_handler.dataset)
demo.queue(max_size=20).launch()