import os
from collections.abc import Iterator
from threading import Thread
from typing import Dict, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """
# QwQ Distill
"""

css = '''
h1 {
    text-align: center;
    display: block;
}

#duplicate-button {
    margin: auto;
    color: #fff;
    background: #1565c0;
    border-radius: 100vh;
}
'''

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()

# Set the pad token ID if it's not already set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id


# Define roles for the chat
class Role:
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


# Default system message
default_system = "You are a helpful assistant."


def clear_session() -> Tuple[str, List]:
    return "", []


def modify_system_session(system: str) -> Tuple[str, str, List]:
    if not system:
        system = default_system
    return system, system, []


def history_to_messages(history: List, system: str) -> List[Dict]:
    """Convert Gradio (user, assistant) history pairs into chat messages."""
    messages = [{'role': Role.SYSTEM, 'content': system}]
    for h in history:
        messages.append({'role': Role.USER, 'content': h[0]})
        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
    return messages


@spaces.GPU(duration=120)
def generate(
    query: Optional[str],
    history: Optional[List],
    system: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    if query is None:
        query = ''
    if history is None:
        history = []

    # Convert history to messages and append the current user turn
    messages = history_to_messages(history, system)
    messages.append({'role': Role.USER, 'content': query})

    # Apply chat template and tokenize
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Enforce the input-token budget (MAX_INPUT_TOKEN_LENGTH was otherwise unused),
    # keeping the most recent tokens of the conversation
    if model_inputs.input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        model_inputs["input_ids"] = model_inputs.input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        model_inputs["attention_mask"] = model_inputs.attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # Set up the streamer for real-time text generation
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Stream the output tokens as they arrive
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)


demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System Message", value=default_system, lines=2),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Write a Python function to reverse a string if its length is a multiple of 4."],
multiple of 4."], ["What is the volume of a pyramid with a rectangular base?"], ["Explain the difference between List comprehension and Lambda in Python."], ["What happens when the sun goes down?"], ], cache_examples=False, description=DESCRIPTION, css=css, fill_height=True, ) if __name__ == "__main__": demo.queue(max_size=20).launch(share=True)