import gradio as gr
import torch
import spaces
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
from threading import Thread
MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
MODEL_NAME = MODEL_ID.split("/")[-1]
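# Use the GPU when one is available; otherwise fall back to the CPU.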
if torch.cuda.is_available():
device = torch.device("cuda")
print("Using GPU:", torch.cuda.get_device_name(0))
else:
device = torch.device("cpu")
print("CUDA is not available. Using CPU.")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)
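# Load the tokenizer and quantized model once at startup; the tokenizer defines
# no pad token, so the end-of-sequence token is reused for padding.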
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
)
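# Streaming text generation. The @spaces.GPU decorator requests a GPU for the
# duration of each call when running on Hugging Face Spaces (ZeroGPU).
# model.generate runs in a background thread and pushes tokens through
# TextIteratorStreamer so the Gradio UI can update incrementally.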
@spaces.GPU
def generate(
user_input,
temperature,
max_tokens,
top_k,
repetition_penalty,
top_p,
prompt_style="",
):
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
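# Polish system prompt: "You are a helpful bot answering questions in Polish.
# Answer briefly and concisely; avoid overly complicated answers."
# The optional prompt_style string appends a style instruction for the rewriting tool.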
system = f"""Jesteś pomocnym botem udzielającym odpowiedzi na pytania w języku polskim.
Odpowiadaj krótko i zwięźle, unikaj zbyt skomplikowanych odpowiedzi.
{prompt_style}
"""
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": user_input})
tokenizer_output = tokenizer.apply_chat_template(
messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
)
if torch.cuda.is_available():
model_input_ids = tokenizer_output.input_ids.to(device)
model_attention_mask = tokenizer_output.attention_mask.to(device)
else:
model_input_ids = tokenizer_output.input_ids
model_attention_mask = tokenizer_output.attention_mask
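# Sampling parameters come straight from the UI sliders; greedy decoding is
# used when the temperature is 0.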
generate_kwargs = {
"input_ids": model_input_ids,
"attention_mask": model_attention_mask,
"streamer": streamer,
"do_sample": True if temperature else False,
"temperature": temperature,
"max_new_tokens": max_tokens,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"top_p": top_p,
}
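# Generation runs in a background thread; this loop consumes the streamer and
# yields the accumulated, cleaned-up text after every new token.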
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
partial_response = ""
for new_token in streamer:
partial_response += new_token
if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
break
# Strip leading whitespace and newlines
cleaned_response = partial_response.lstrip("\n").lstrip()
yield cleaned_response
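# System prompts (in Polish) for the style-rewriting tool:
# "Formalny"    - rewrite the text to be more formal, preserving meaning and clarity
# "Nieformalny" - rewrite the text to be looser and more informal, so it sounds casual and natural
# "Neutralny"   - rewrite the text to be more neutral, removing overly formal or colloquial phrasing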
STYLE_PROMPTS = {
"Formalny": """Przekształć poniższy tekst na bardziej formalny, zachowując jego oryginalne znaczenie i klarowność.""", # noqa
"Nieformalny": """Przekształć poniższy tekst na luźniejszy i bardziej nieformalny, tak żeby brzmiał swobodnie i naturalnie..""", # noqa
"Neutralny": """Przekształć poniższy tekst na bardziej neutralny, eliminując zbyt formalne lub potoczne sformułowania.""", # noqa
}
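# Gradio UI: one page with three tools - "Zadaj Pytanie" (ask a question),
# "Zmiana stylu" (change style) and "Sędzia" (judge, currently a disabled
# placeholder). The buttons at the top switch which tool column is visible.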
with gr.Blocks(
css="""
.gradio-container { max-width: 1600px; margin: 20px; padding: 10px; }
#style-dropdown { flex: 3; }
#generate-btn, #clear-btn { flex: 1; max-width: 100px; }
.same-height { height: 60px; }
"""
) as demo:
gr.Markdown("# Bielik Tools - narzędzia dla modelu Bielik v2.3")
with gr.Column(elem_id="main-content"):
with gr.Row():
simple_question_btn = gr.Button("Zadaj Pytanie", variant="primary")
formalizer_btn = gr.Button("Zmiana stylu", variant="secondary")
judge_btn = gr.Button("Sędzia", interactive=False)
# Function to switch tool visibility and update button styles based on the active tool
def switch_tool(tool):
print(f"Switched to {tool}")
return [
gr.update(variant="primary" if tool == "Formalizer" else "secondary"),
gr.update(variant="primary" if tool == "Judge" else "secondary"),
gr.update(variant="primary" if tool == "Simple Question" else "secondary"),
gr.update(visible=(tool == "Formalizer")),
gr.update(visible=(tool == "Judge")),
gr.update(visible=(tool == "Simple Question")),
]
# Simple Question content column
with gr.Column(visible=True) as simple_question_column:
input_text_sq = gr.Textbox(
label="Twoje pytanie",
placeholder="Zadaj swoje pytanie tutaj...",
lines=5,
)
with gr.Row():
generate_btn_sq = gr.Button("Generuj odpowiedź", interactive=False)
clear_btn_sq = gr.Button("Wyczyść", interactive=False)
output_text_sq = gr.Textbox(label="Odpowiedź", interactive=False, lines=5)
with gr.Accordion("⚙️ Parametry", open=False):
temperature_sq = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
max_tokens_sq = gr.Slider(
128, 4096, 1024, label="Maksymalna długość odpowiedzi"
)
top_k_sq = gr.Slider(1, 80, 40, step=1, label="Top K")
repetition_penalty_sq = gr.Slider(
0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
)
top_p_sq = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")
# Update button states based on input and output text changes for interactivity
def update_button_states_sq(input_text, output_text):
return [
gr.update(interactive=bool(input_text)),
gr.update(interactive=bool(input_text or output_text)),
]
input_text_sq.change(
update_button_states_sq,
inputs=[input_text_sq, output_text_sq],
outputs=[generate_btn_sq, clear_btn_sq],
)
output_text_sq.change(
update_button_states_sq,
inputs=[input_text_sq, output_text_sq],
outputs=[generate_btn_sq, clear_btn_sq],
)
# Event handlers for button actions to process and clear text
generate_btn_sq.click(
fn=generate,
inputs=[
input_text_sq,
temperature_sq,
max_tokens_sq,
top_k_sq,
repetition_penalty_sq,
top_p_sq,
],
outputs=output_text_sq,
)
clear_btn_sq.click(
fn=lambda: ("", ""),
inputs=None,
outputs=[input_text_sq, output_text_sq],
)
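# Style-changing tool ("Zmiana stylu", internally "Formalizer"): rewrites the
# input text in the selected style.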
with gr.Column(visible=False) as formalizer_column:
input_text = gr.Textbox(
placeholder="Wpisz tekst tutaj...", label="Twój tekst", lines=5
)
with gr.Row():
gr.Text(
"Wybierz styl:",
elem_id="style-label",
show_label=False,
elem_classes="same-height",
)
style_dropdown = gr.Dropdown(
choices=["Formalny", "Nieformalny", "Neutralny"],
value="Neutralny", # Set a default value
elem_id="style-dropdown",
show_label=False,
elem_classes="same-height",
)
generate_btn = gr.Button(
"Generuj",
interactive=False,
elem_id="generate-btn",
elem_classes="same-height",
)
clear_btn = gr.Button(
"Wyczyść",
interactive=False,
elem_id="clear-btn",
elem_classes="same-height",
)
output_text = gr.Textbox(label="Wynik", interactive=False, lines=5)
# Update button states based on input and output text changes for interactivity
def update_button_states(input_text, output_text):
return [
gr.update(interactive=bool(input_text)),
gr.update(interactive=bool(input_text or output_text)),
]
input_text.change(
update_button_states,
inputs=[input_text, output_text],
outputs=[generate_btn, clear_btn],
)
output_text.change(
update_button_states,
inputs=[input_text, output_text],
outputs=[generate_btn, clear_btn],
)
# Event handlers for button actions to process and clear text
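# Stream the restyled text by reusing generate() with fixed sampling settings
# and the style-specific system prompt.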
def format_with_style(text, style):
partial_text = ""
for chunk in generate(
text,
temperature=0.3,
max_tokens=1024,
top_k=40,
repetition_penalty=1.1,
top_p=0.95,
prompt_style=STYLE_PROMPTS[style]
):
partial_text = chunk
yield partial_text
generate_btn.click(
fn=format_with_style,
inputs=[input_text, style_dropdown],
outputs=output_text,
)
clear_btn.click(
fn=lambda: ("", ""), inputs=None, outputs=[input_text, output_text]
)
# Placeholder for Judge content column, initially hidden
with gr.Column(visible=False) as judge_column:
gr.Markdown("Judge tool content goes here.")
with gr.Accordion("⚙️ Parametry", open=False):
temperature_jg = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
max_tokens_jg = gr.Slider(
128, 4096, 1024, label="Maksymalna długość odpowiedzi"
)
top_k_jg = gr.Slider(1, 80, 40, step=1, label="Top K")
repetition_penalty_jg = gr.Slider(
0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
)
top_p_jg = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")
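# Wire the top-level buttons to the tool switcher: each click updates the
# button variants and toggles which tool column is visible.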
formalizer_btn.click(
lambda: switch_tool("Formalizer"),
outputs=[
formalizer_btn,
judge_btn,
simple_question_btn,
formalizer_column,
judge_column,
simple_question_column,
],
)
judge_btn.click(
lambda: switch_tool("Judge"),
outputs=[
formalizer_btn,
judge_btn,
simple_question_btn,
formalizer_column,
judge_column,
simple_question_column,
],
)
simple_question_btn.click(
lambda: switch_tool("Simple Question"),
outputs=[
formalizer_btn,
judge_btn,
simple_question_btn,
formalizer_column,
judge_column,
simple_question_column,
],
)
demo.queue().launch()