Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import spaces | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
BitsAndBytesConfig, | |
TextIteratorStreamer, | |
) | |
from threading import Thread | |
# Hugging Face Hub id of the checkpoint served by this app.
MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
# Last path component of the id, i.e. the bare model name.
MODEL_NAME = MODEL_ID.rsplit("/", 1)[-1]

# Choose the device once at import time; generate() moves its inputs there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available. Using CPU.")

# Load the weights in 4-bit via bitsandbytes, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
def generate(
    user_input,
    temperature,
    max_tokens,
    top_k,
    repetition_penalty,
    top_p,
    prompt_style="",
):
    """Stream a model reply to ``user_input`` as progressively longer text.

    The prompt is a fixed Polish system message (with ``prompt_style``
    appended) followed by the user message. ``model.generate`` runs on a
    background thread and its decoded output is read back through a
    ``TextIteratorStreamer``.

    Args:
        user_input: Text of the user message.
        temperature: Sampling temperature; a falsy value (0) selects greedy
            decoding instead of sampling.
        max_tokens: Upper bound on the number of newly generated tokens.
        top_k: Top-k cutoff, forwarded only when sampling.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Nucleus (top-p) cutoff, forwarded only when sampling.
        prompt_style: Extra instruction appended to the system message.

    Yields:
        The response accumulated so far, with leading whitespace removed.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the stream has ended.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    system = f"""Jesteś pomocnym botem udzielającym odpowiedzi na pytania w języku polskim.
Odpowiadaj krótko i zwięźle, unikaj zbyt skomplikowanych odpowiedzi.
{prompt_style}
"""
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user_input},
    ]

    # add_generation_prompt asks the template to open the assistant turn, so
    # the model answers instead of possibly continuing the user's message.
    tokenizer_output = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    do_sample = bool(temperature)
    # UI sliders may deliver whole numbers as int (e.g. 1 instead of 1.0);
    # coerce so every parameter reaches generate() with its expected type.
    generate_kwargs = {
        # .to(device) is a no-op when the tensors already live on `device`.
        "input_ids": tokenizer_output.input_ids.to(device),
        "attention_mask": tokenizer_output.attention_mask.to(device),
        "streamer": streamer,
        "do_sample": do_sample,
        "max_new_tokens": int(max_tokens),
        "repetition_penalty": float(repetition_penalty),
    }
    if do_sample:
        # Sampling-only knobs are skipped for greedy decoding.
        generate_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    worker_errors = []

    def _run_generation():
        # Runs model.generate; on failure records the error and closes the
        # stream so the consumer loop below cannot block forever.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the consuming side
            worker_errors.append(exc)
            streamer.end()

    Thread(target=_run_generation).start()

    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        # Stop as soon as an end-of-turn marker shows up in the decoded text.
        if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
            break
        yield partial_response.lstrip()

    if worker_errors:
        raise worker_errors[0]
# Maps each style name offered in the UI dropdown to the Polish instruction
# that is appended to the system prompt for that style.
STYLE_PROMPTS = {
    "Formalny": (
        "Przekształć poniższy tekst na bardziej formalny, "
        "zachowując jego oryginalne znaczenie i klarowność."
    ),
    "Nieformalny": (
        "Przekształć poniższy tekst na luźniejszy i bardziej nieformalny, "
        "tak żeby brzmiał swobodnie i naturalnie.."
    ),
    "Neutralny": (
        "Przekształć poniższy tekst na bardziej neutralny, "
        "eliminując zbyt formalne lub potoczne sformułowania."
    ),
}
# Gradio UI: three tool buttons at the top switch between the
# "Simple Question", "Formalizer" (style change) and "Judge" panels.
# NOTE(review): the container nesting below was reconstructed from a copy of
# this file that had lost its indentation - confirm the layout renders as
# intended.
with gr.Blocks(
    css="""
.gradio-container { max-width: 1600px; margin: 20px; padding: 10px; }
#style-dropdown { flex: 3; }
#generate-btn, #clear-btn { flex: 1; max-width: 100px; }
.same-height { height: 60px; }
"""
) as demo:
    gr.Markdown("# Bielik Tools - narzędzia dla modelu Bielik v2.3")

    with gr.Column(elem_id="main-content"):
        with gr.Row():
            simple_question_btn = gr.Button("Zadaj Pytanie", variant="primary")
            formalizer_btn = gr.Button("Zmiana stylu", variant="secondary")
            judge_btn = gr.Button("Sędzia", interactive=False)

        def switch_tool(tool):
            """Return button-variant and column-visibility updates for ``tool``.

            The result order is: formalizer button, judge button, simple
            question button, formalizer column, judge column, simple question
            column. The button of the active tool becomes "primary" and only
            its column stays visible.
            """
            print(f"Switched to {tool}")
            return [
                gr.Button(variant="primary" if tool == "Formalizer" else "secondary"),
                gr.Button(variant="primary" if tool == "Judge" else "secondary"),
                gr.Button(
                    variant="primary" if tool == "Simple Question" else "secondary"
                ),
                gr.update(visible=(tool == "Formalizer")),
                gr.update(visible=(tool == "Judge")),
                gr.update(visible=(tool == "Simple Question")),
            ]

        # Simple Question content column
        with gr.Column(visible=True) as simple_question_column:
            input_text_sq = gr.Textbox(
                label="Twoje pytanie",
                placeholder="Zadaj swoje pytanie tutaj...",
                lines=5,
            )
            with gr.Row():
                generate_btn_sq = gr.Button("Generuj odpowiedź", interactive=False)
                clear_btn_sq = gr.Button("Wyczyść", interactive=False)
            output_text_sq = gr.Textbox(label="Odpowiedź", interactive=False, lines=5)

            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_sq = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_sq = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_sq = gr.Slider(1, 80, 40, step=1, label="Top K")
                # Minimum is 0.1 rather than 0: transformers rejects a
                # repetition penalty that is not strictly positive.
                repetition_penalty_sq = gr.Slider(
                    0.1, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_sq = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

            def update_button_states_sq(input_text, output_text):
                """Return interactivity updates for the (generate, clear) buttons.

                Generate is enabled when there is input text; clear is enabled
                when there is anything to clear in either textbox. Exactly one
                update is returned per output component.
                """
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )
            output_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )

            # Event handlers for button actions to process and clear text
            generate_btn_sq.click(
                fn=generate,
                inputs=[
                    input_text_sq,
                    temperature_sq,
                    max_tokens_sq,
                    top_k_sq,
                    repetition_penalty_sq,
                    top_p_sq,
                ],
                outputs=output_text_sq,
            )
            clear_btn_sq.click(
                fn=lambda: ("", ""),
                inputs=None,
                outputs=[input_text_sq, output_text_sq],
            )

        # Formalizer (style change) content column, initially hidden
        with gr.Column(visible=False) as formalizer_column:
            input_text = gr.Textbox(
                placeholder="Wpisz tekst tutaj...", label="Twój tekst", lines=5
            )
            with gr.Row():
                gr.Text(
                    "Wybierz styl:",
                    elem_id="style-label",
                    show_label=False,
                    elem_classes="same-height",
                )
                style_dropdown = gr.Dropdown(
                    choices=["Formalny", "Nieformalny", "Neutralny"],
                    value="Neutralny",  # Set a default value
                    elem_id="style-dropdown",
                    show_label=False,
                    elem_classes="same-height",
                )
                generate_btn = gr.Button(
                    "Generuj",
                    interactive=False,
                    elem_id="generate-btn",
                    elem_classes="same-height",
                )
                clear_btn = gr.Button(
                    "Wyczyść",
                    interactive=False,
                    elem_id="clear-btn",
                    elem_classes="same-height",
                )
            output_text = gr.Textbox(label="Wynik", interactive=False, lines=5)

            def update_button_states(input_text, output_text):
                """Return interactivity updates for the (generate, clear) buttons.

                Generate is enabled when there is input text; clear is enabled
                when there is anything to clear in either textbox. Exactly one
                update is returned per output component.
                """
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )
            output_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )

            def format_with_style(text, style):
                """Stream ``text`` rewritten in ``style`` using fixed parameters.

                ``style`` must be a key of ``STYLE_PROMPTS``; every partial
                result produced by ``generate`` is re-yielded unchanged.
                """
                yield from generate(
                    text,
                    temperature=0.3,
                    max_tokens=1024,
                    top_k=40,
                    repetition_penalty=1.1,
                    top_p=0.95,
                    prompt_style=STYLE_PROMPTS[style],
                )

            generate_btn.click(
                fn=format_with_style,
                inputs=[input_text, style_dropdown],
                outputs=output_text,
            )
            clear_btn.click(
                fn=lambda: ("", ""), inputs=None, outputs=[input_text, output_text]
            )

        # Placeholder for Judge content column, initially hidden
        with gr.Column(visible=False) as judge_column:
            gr.Markdown("Judge tool content goes here.")
            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_jg = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_jg = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_jg = gr.Slider(1, 80, 40, step=1, label="Top K")
                # Same strictly-positive lower bound as the Simple Question slider.
                repetition_penalty_jg = gr.Slider(
                    0.1, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_jg = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

    # Components updated by switch_tool, in the order it returns its updates.
    _tool_switch_outputs = [
        formalizer_btn,
        judge_btn,
        simple_question_btn,
        formalizer_column,
        judge_column,
        simple_question_column,
    ]
    # Each tool button is wired exactly once; registering a listener twice
    # would run switch_tool twice per click.
    formalizer_btn.click(
        lambda: switch_tool("Formalizer"), outputs=_tool_switch_outputs
    )
    judge_btn.click(lambda: switch_tool("Judge"), outputs=_tool_switch_outputs)
    simple_question_btn.click(
        lambda: switch_tool("Simple Question"), outputs=_tool_switch_outputs
    )

demo.queue().launch()