# Hugging Face Space: running on ZeroGPU
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
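# Dependencies: gradio, torch, and transformers must be installed, plus
# bitsandbytes (for 4-bit loading) and accelerate (for device_map="auto" placement).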
# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# === CONFIGURE 4-BIT QUANTIZATION ===
# (from_pretrained initializes weights on the meta device itself when a device_map
# is given, so no explicit init_empty_weights() context is needed here)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # keep any layers that end up on the CPU in fp32 instead of raising an error
    llm_int8_enable_fp32_cpu_offload=True
)
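# Rough sizing, assuming the checkpoint really has ~70B parameters: NF4 weights take
# about 70e9 * 0.5 bytes ≈ 35 GB before overhead, which is why the load below caps
# per-GPU memory and lets the rest spill to CPU RAM and disk.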
# === LOAD MODEL WITH QUANTIZATION AND CPU/DISK OFFLOAD ===
# device_map="auto" fills the GPUs up to max_memory, spills the remainder to CPU
# RAM, and offloads anything beyond that to the "offload" folder on disk.
model = AutoModelForCausalLM.from_pretrained(
    phi4_model_path,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={**{i: "12GB" for i in range(torch.cuda.device_count())}, "cpu": "30GB"},
    offload_folder="offload",
    offload_state_dict=True
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
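# Optional sanity check: accelerate records the final placement of every module in
# model.hf_device_map, so you can confirm how much of the model actually landed on
# GPU versus CPU versus disk.
# print(model.hf_device_map)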
# Enable gradient checkpointing in case the model is ever fine-tuned
model.gradient_checkpointing_enable()

# Optionally compile the model on PyTorch >= 2.0
try:
    model = torch.compile(model)
except Exception:
    pass
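# torch.compile can fail on models with quantized or offloaded modules, so the
# try/except above falls back to the uncompiled model instead of crashing at startup.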
# === RESPONSE GENERATOR ===
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    # generate_response is a generator (it yields below), so an empty message must
    # also be handled with a yield; a bare "return value" would end the generator
    # without ever sending the history back to Gradio.
    if not user_message.strip():
        yield history_state, history_state
        return
    # Prompt setup
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        tag = msg["role"]
        content = msg["content"]
        prompt += f"{start_tag}{tag}{sep_tag}{content}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
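    # For reference, a single-turn prompt assembled above looks like
    #   <|im_start|>system<|im_sep|>...<|im_end|><|im_start|>user<|im_sep|>Hi<|im_end|><|im_start|>assistant<|im_sep|>
    # (the Phi-4-style chat format these tags come from); the model continues from
    # the trailing assistant header.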
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Streaming setup
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer
    }

    # Run generation in a background thread while this thread drains the streamer
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]

    # Stream tokens as they arrive, stripping the chat-format tags
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history
    yield new_history, new_history
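# Because generate_response is a generator, Gradio streams each yielded
# (chatbot, history) pair to the UI as it is produced rather than waiting for the
# full completion; the request queue this relies on is enabled by default in
# current Gradio releases.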
# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity..."
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Phi-4 Chat
Try the example problems below to see how the model breaks down complex reasoning.
""")
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name in example_messages:
                    btn = gr.Button(name)
                    btn.click(fn=lambda n=name: gr.update(value=example_messages[n]), inputs=None, outputs=user_input)

    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)
    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)
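# ssr_mode=False disables Gradio's server-side rendering (a launch() option in
# Gradio 5), which is sometimes turned off on Spaces to work around SSR-related
# startup issues.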