import gradio as gr
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
from threading import Thread

MODEL_ID = "speakleash/Bielik-11B-v2.3-Instruct"
MODEL_NAME = MODEL_ID.split("/")[-1]

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

# Quantize the model to 4 bits so the 11B checkpoint fits in GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)


@spaces.GPU
def generate(
    user_input,
    temperature,
    max_tokens,
    top_k,
    repetition_penalty,
    top_p,
    prompt_style="",
):
    """Stream a response: run model.generate in a background thread and yield partial text as tokens arrive."""
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    system = f"""Jesteś pomocnym botem udzielającym odpowiedzi na pytania w języku polskim. Odpowiadaj krótko i zwięźle, unikaj zbyt skomplikowanych odpowiedzi.
{prompt_style}
"""

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user_input})

    tokenizer_output = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    )

    if torch.cuda.is_available():
        model_input_ids = tokenizer_output.input_ids.to(device)
        model_attention_mask = tokenizer_output.attention_mask.to(device)
    else:
        model_input_ids = tokenizer_output.input_ids
        model_attention_mask = tokenizer_output.attention_mask

    generate_kwargs = {
        "input_ids": model_input_ids,
        "attention_mask": model_attention_mask,
        "streamer": streamer,
        "do_sample": True if temperature else False,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
    }

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        if "<|im_end|>" in partial_response or "<|endoftext|>" in partial_response:
            break

        # Strip leading whitespace and newlines
        cleaned_response = partial_response.lstrip("\n").lstrip()
        yield cleaned_response


STYLE_PROMPTS = {
    "Formalny": """Przekształć poniższy tekst na bardziej formalny, zachowując jego oryginalne znaczenie i klarowność.""",  # noqa
    "Nieformalny": """Przekształć poniższy tekst na luźniejszy i bardziej nieformalny, tak żeby brzmiał swobodnie i naturalnie.""",  # noqa
    "Neutralny": """Przekształć poniższy tekst na bardziej neutralny, eliminując zbyt formalne lub potoczne sformułowania.""",  # noqa
}

with gr.Blocks(
    css="""
    .gradio-container { max-width: 1600px; margin: 20px; padding: 10px; }
    #style-dropdown { flex: 3; }
    #generate-btn, #clear-btn { flex: 1; max-width: 100px; }
    .same-height { height: 60px; }
    """
) as demo:
    gr.Markdown("# Bielik Tools - narzędzia dla modelu Bielik v2.3")

    with gr.Column(elem_id="main-content"):
        with gr.Row():
            simple_question_btn = gr.Button("Zadaj Pytanie", variant="primary")
            formalizer_btn = gr.Button("Zmiana stylu", variant="secondary")
            judge_btn = gr.Button("Sędzia", interactive=False)

        # Function to switch tool visibility and update button styles based on the active tool
        def switch_tool(tool):
            print(f"Switched to {tool}")
            return [
                gr.Button(variant="primary" if tool == "Formalizer" else "secondary"),
                gr.Button(variant="primary" if tool == "Judge" else "secondary"),
                gr.Button(
                    variant="primary" if tool == "Simple Question" else "secondary"
                ),
                gr.update(visible=(tool == "Formalizer")),
                gr.update(visible=(tool == "Judge")),
                gr.update(visible=(tool == "Simple Question")),
            ]
        # Simple Question content column
        with gr.Column(visible=True) as simple_question_column:
            input_text_sq = gr.Textbox(
                label="Twoje pytanie",
                placeholder="Zadaj swoje pytanie tutaj...",
                lines=5,
            )
            with gr.Row():
                generate_btn_sq = gr.Button("Generuj odpowiedź", interactive=False)
                clear_btn_sq = gr.Button("Wyczyść", interactive=False)
            output_text_sq = gr.Textbox(label="Odpowiedź", interactive=False, lines=5)

            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_sq = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_sq = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_sq = gr.Slider(1, 80, 40, step=1, label="Top K")
                repetition_penalty_sq = gr.Slider(
                    0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_sq = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

            # Update button states based on input and output text changes for interactivity
            def update_button_states_sq(input_text, output_text):
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )
            output_text_sq.change(
                update_button_states_sq,
                inputs=[input_text_sq, output_text_sq],
                outputs=[generate_btn_sq, clear_btn_sq],
            )

            # Event handlers for button actions to process and clear text
            generate_btn_sq.click(
                fn=generate,
                inputs=[
                    input_text_sq,
                    temperature_sq,
                    max_tokens_sq,
                    top_k_sq,
                    repetition_penalty_sq,
                    top_p_sq,
                ],
                outputs=output_text_sq,
            )
            clear_btn_sq.click(
                fn=lambda: ("", ""),
                inputs=None,
                outputs=[input_text_sq, output_text_sq],
            )

        # Formalizer (style change) content column, initially hidden
        with gr.Column(visible=False) as formalizer_column:
            input_text = gr.Textbox(
                placeholder="Wpisz tekst tutaj...", label="Twój tekst", lines=5
            )
            with gr.Row():
                gr.Text(
                    "Wybierz styl:",
                    elem_id="style-label",
                    show_label=False,
                    elem_classes="same-height",
                )
                style_dropdown = gr.Dropdown(
                    choices=["Formalny", "Nieformalny", "Neutralny"],
                    value="Neutralny",  # Set a default value
                    elem_id="style-dropdown",
                    show_label=False,
                    elem_classes="same-height",
                )
                generate_btn = gr.Button(
                    "Generuj",
                    interactive=False,
                    elem_id="generate-btn",
                    elem_classes="same-height",
                )
                clear_btn = gr.Button(
                    "Wyczyść",
                    interactive=False,
                    elem_id="clear-btn",
                    elem_classes="same-height",
                )
            output_text = gr.Textbox(label="Wynik", interactive=False, lines=5)

            # Update button states based on input and output text changes for interactivity
            def update_button_states(input_text, output_text):
                return [
                    gr.update(interactive=bool(input_text)),
                    gr.update(interactive=bool(input_text or output_text)),
                ]

            input_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )
            output_text.change(
                update_button_states,
                inputs=[input_text, output_text],
                outputs=[generate_btn, clear_btn],
            )

            # Event handlers for button actions to process and clear text
            def format_with_style(text, style):
                partial_text = ""
                for chunk in generate(
                    text,
                    temperature=0.3,
                    max_tokens=1024,
                    top_k=40,
                    repetition_penalty=1.1,
                    top_p=0.95,
                    prompt_style=STYLE_PROMPTS[style],
                ):
                    partial_text = chunk
                    yield partial_text

            generate_btn.click(
                fn=format_with_style,
                inputs=[input_text, style_dropdown],
                outputs=output_text,
            )
            clear_btn.click(
                fn=lambda: ("", ""), inputs=None, outputs=[input_text, output_text]
            )
        # Placeholder for Judge content column, initially hidden
        with gr.Column(visible=False) as judge_column:
            gr.Markdown("Judge tool content goes here.")

            with gr.Accordion("⚙️ Parametry", open=False):
                temperature_jg = gr.Slider(0, 1, 0.3, step=0.1, label="Temperatura")
                max_tokens_jg = gr.Slider(
                    128, 4096, 1024, label="Maksymalna długość odpowiedzi"
                )
                top_k_jg = gr.Slider(1, 80, 40, step=1, label="Top K")
                repetition_penalty_jg = gr.Slider(
                    0, 2, 1.1, step=0.1, label="Penalizacja powtórzeń"
                )
                top_p_jg = gr.Slider(0, 1, 0.95, step=0.05, label="Top P")

        # Wire the top-row buttons to switch_tool so the matching column is shown
        # and the button variants reflect the active tool.
        formalizer_btn.click(
            lambda: switch_tool("Formalizer"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )
        judge_btn.click(
            lambda: switch_tool("Judge"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )
        simple_question_btn.click(
            lambda: switch_tool("Simple Question"),
            outputs=[
                formalizer_btn,
                judge_btn,
                simple_question_btn,
                formalizer_column,
                judge_column,
                simple_question_column,
            ],
        )

demo.queue().launch()