import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer import time import os # --- Configuration --- BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct" # Path to your merged fine-tuned model within the Hugging Face Space # If 'cineguide-merged' is at the root of your Space repo: FINETUNED_MODEL_PATH = "cineguide-merged" # System prompts SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to: 1. Provide personalized movie recommendations based on user preferences 2. Give brief, compelling rationales for why you recommend each movie 3. Ask thoughtful follow-up questions to better understand user tastes 4. Maintain an enthusiastic but not overwhelming tone about cinema When recommending movies, always explain WHY the movie fits their preferences.""" SYSTEM_PROMPT_BASE = "You are a helpful AI assistant." # --- Model Loading --- # Cache models globally so they are loaded only once _models_cache = {} def get_model_and_tokenizer(model_id_or_path): if model_id_or_path in _models_cache: return _models_cache[model_id_or_path] print(f"Loading model: {model_id_or_path}") tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_id_or_path, torch_dtype=torch.bfloat16, # Use bfloat16 for faster inference device_map="auto", # Automatically distribute across GPUs if available trust_remote_code=True, # attn_implementation="flash_attention_2" # Optional: if supported by Space hardware & transformers version ) model.eval() # Set to evaluation mode if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id _models_cache[model_id_or_path] = (model, tokenizer) print(f"Finished loading: {model_id_or_path}") return model, tokenizer # Pre-load models when the script starts # This can take time, so Gradio might show a loading screen. # For Spaces, this happens during the build/startup phase. print("Pre-loading models...") try: model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID) print("Base model loaded.") except Exception as e: print(f"Error loading base model: {e}") model_base, tokenizer_base = None, None # Check if fine-tuned model path exists before loading if os.path.exists(FINETUNED_MODEL_PATH) and os.path.isdir(FINETUNED_MODEL_PATH): try: model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_PATH) print("Fine-tuned model loaded.") except Exception as e: print(f"Error loading fine-tuned model from {FINETUNED_MODEL_PATH}: {e}") model_ft, tokenizer_ft = None, None else: print(f"Fine-tuned model path not found: {FINETUNED_MODEL_PATH}. Skipping fine-tuned model.") model_ft, tokenizer_ft = None, None print("Model pre-loading complete.") # --- Inference Function --- def generate_chat_response(message: str, chat_history: list, model_type: str): if model_type == "base": model, tokenizer = model_base, tokenizer_base system_prompt = SYSTEM_PROMPT_BASE elif model_type == "finetuned": model, tokenizer = model_ft, tokenizer_ft system_prompt = SYSTEM_PROMPT_CINEGUIDE else: yield "Invalid model type." return if model is None or tokenizer is None: yield f"Model '{model_type}' is not available." return conversation = [] if system_prompt: conversation.append({"role": "system", "content": system_prompt}) for user_msg, assistant_msg in chat_history: conversation.append({"role": "user", "content": user_msg}) conversation.append({"role": "assistant", "content": assistant_msg}) conversation.append({"role": "user", "content": message}) # Apply chat template prompt = tokenizer.apply_chat_template( conversation, tokenize=False, add_generation_prompt=True # This adds the <|im_start|>assistant prefix ) inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict( inputs, streamer=streamer, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) # For streaming, run generation in a separate thread # For Gradio, we can yield partial results # However, TextStreamer prints to stdout. For Gradio, we need to capture. # Simpler non-streaming approach for direct yield: # Remove streamer from generation_kwargs # outputs = model.generate(**generation_kwargs_without_streamer) # decoded_output = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) # yield decoded_output # More complex streaming for Gradio: full_response = "" generated_token_ids = model.generate( **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")] ) # Decode only the newly generated tokens new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:] response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() response_text = response_text.replace("<|im_end|>", "").strip() # Clean up # Yield character by character for streaming effect (can be slow for long responses) # A better way is to yield chunks. For simplicity, this is char by char. for char in response_text: full_response += char time.sleep(0.005) # Adjust for desired speed yield full_response def respond_base(message, chat_history): # chat_history is a list of [user_msg, assistant_msg] yield from generate_chat_response(message, chat_history, "base") def respond_finetuned(message, chat_history): yield from generate_chat_response(message, chat_history, "finetuned") # --- Gradio UI --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 🎬 CineGuide vs. Base Qwen2.5-7B-Instruct Compare the fine-tuned CineGuide movie recommender with the base Qwen2.5-7B-Instruct model. Type your movie-related query below and see how each model responds! """ ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("## 🗣️ Base Qwen2.5-7B-Instruct") chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False) if model_base is None: gr.Markdown("⚠️ Base model could not be loaded. This chat interface will not work.") with gr.Column(scale=1): gr.Markdown("## 🤖 Fine-tuned CineGuide (Qwen2.5-7B)") chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False) if model_ft is None: gr.Markdown("⚠️ Fine-tuned model could not be loaded. This chat interface will not work.") with gr.Row(): shared_input_textbox = gr.Textbox( show_label=False, placeholder="Enter your movie query here and press Enter...", container=False, scale=7, # Make it wider ) submit_button = gr.Button("✉️ Send", variant="primary", scale=1) # clear_button = gr.Button("🗑️ Clear All", scale=1) # If you want a single clear button # Predefined examples gr.Examples( examples=[ "Hi! I'm looking for something funny to watch tonight.", "I love dry, witty humor more than slapstick. Think more British comedy style.", "I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.", "I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.", "I'm going through a tough breakup and need something uplifting but not cheesy romantic.", "I loved Parasite and want to explore more international cinema. Where should I start?", ], inputs=[shared_input_textbox], # outputs=[chatbot_base, chatbot_ft], # Examples don't directly populate chatbots # fn=lambda x: (None, None), # Dummy function for examples label="Example Prompts (click to use)" ) # Event handlers def handle_submit(user_message, chat_history_base, chat_history_ft): # This will return iterators. Gradio handles them for streaming. # Important: chat_history is updated by Gradio automatically by returning (user_message, bot_message_chunk) # For simultaneous updates, we need to manage history carefully or use a trick. # Gradio's chatbot expects the history list to be updated. # The `respond_base` and `respond_finetuned` functions already take history. # The issue is that Gradio wants a function that returns the new state of the chatbot. # Simplest for simultaneous: return None for the other chatbot if we trigger one by one. # For true simultaneous, you'd need a more complex setup or separate submit buttons. # Let's make them update sequentially for simplicity with one input. # Update base model chat chat_history_base.append((user_message, None)) # Add user message # The `yield` from respond_base will update the last message (None) # Update fine-tuned model chat chat_history_ft.append((user_message, None)) # Add user message # We need to return generators that Gradio can iterate over # This won't work directly as Gradio expects outputs to be bound to specific components. # We need to make the function return the new state for *both* chatbots. # The `respond_base` and `respond_finetuned` should update their respective histories. # Gradio's Chatbot expects (message, history) -> history or (message, history) -> yield history_updates # Let's define wrapper functions for the submit action. return "", chat_history_base, chat_history_ft # Clear textbox, pass history def base_model_predict(user_message, chat_history): chat_history.append((user_message, "")) # Add user message and placeholder for bot for response_chunk in respond_base(user_message, chat_history[:-1]): # Pass history without current turn chat_history[-1] = (user_message, response_chunk) yield chat_history def ft_model_predict(user_message, chat_history): chat_history.append((user_message, "")) for response_chunk in respond_finetuned(user_message, chat_history[:-1]): chat_history[-1] = (user_message, response_chunk) yield chat_history # When shared_input_textbox is submitted or submit_button is clicked: if model_base is not None: shared_input_textbox.submit( base_model_predict, [shared_input_textbox, chatbot_base], [chatbot_base], ) submit_button.click( base_model_predict, [shared_input_textbox, chatbot_base], [chatbot_base], ) if model_ft is not None: shared_input_textbox.submit( ft_model_predict, [shared_input_textbox, chatbot_ft], [chatbot_ft], ) submit_button.click( ft_model_predict, [shared_input_textbox, chatbot_ft], [chatbot_ft], ) # After both predictions are done (or if one is skipped), clear the input textbox # This is a bit tricky with simultaneous submits. # A simpler way is to clear it on the second submit if both models are active. # Or, let Gradio handle textbox clearing by returning "" as the first element of the outputs list. # If ft_model_predict is the last one to be called from submit: if model_ft is not None: shared_input_textbox.submit(lambda: "", [], [shared_input_textbox]) submit_button.click(lambda: "", [], [shared_input_textbox]) elif model_base is not None: # If only base model is active shared_input_textbox.submit(lambda: "", [], [shared_input_textbox]) submit_button.click(lambda: "", [], [shared_input_textbox]) # Clear buttons (Individual) # clear_base_btn = gr.Button("🗑️ Clear Base Chat") # clear_ft_btn = gr.Button("🗑️ Clear CineGuide Chat") # clear_base_btn.click(lambda: (None, ""), None, [chatbot_base, shared_input_textbox], queue=False) # clear_ft_btn.click(lambda: (None, ""), None, [chatbot_ft, shared_input_textbox], queue=False) # --- Launch the App --- if __name__ == "__main__": demo.queue() # Enable queuing for handling multiple users demo.launch(debug=True, share=False) # share=True for public link if running locally