import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os

# Attempt to import the spaces GPU decorator.
# This is a common pattern, but the exact import might vary or be injected.
try:
    import spaces  # This might make spaces.GPU available
except ImportError:
    spaces = None  # Define it as None if import fails, so we can check later
    print("WARNING: 'spaces' module not found. @spaces.GPU decorator might not be available or work as expected.")

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"  # Confirmed by you as correct

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Storage (placeholders) ---
# We will store model objects here after they are loaded within a GPU context.
# This addresses John6666's point about global variables not updating correctly
# if modified outside the main Gradio event flow or GPU context.
# We'll treat these more like a cache that's populated by GPU-context functions.
MODELS_LOADED = {
    "base_model": None,
    "base_tokenizer": None,
    "ft_model": None,
    "ft_tokenizer": None,
    "base_load_error": None,
    "ft_load_error": None,
}


# --- Core Model Loading and Inference Logic (to be wrapped by @spaces.GPU) ---
def _load_and_infer(message: str, chat_history: list, model_id_to_load: str, system_prompt: str, model_kind: str):
    """
    This function handles loading (if necessary) and inference.
    It's designed to be called by a function decorated with @spaces.GPU.
    """
    model_key = f"{model_kind}_model"
    tokenizer_key = f"{model_kind}_tokenizer"
    error_key = f"{model_kind}_load_error"

    # Check if model failed to load previously
    if MODELS_LOADED[error_key]:
        yield f"Previous attempt to load {model_kind} model ({model_id_to_load}) failed: {MODELS_LOADED[error_key]}"
        return

    # Load model and tokenizer if not already loaded
    if MODELS_LOADED[model_key] is None or MODELS_LOADED[tokenizer_key] is None:
        print(f"Attempting to load {model_kind} model: {model_id_to_load} (Type: {type(model_id_to_load)})")
        if not model_id_to_load or not isinstance(model_id_to_load, str):
            MODELS_LOADED[error_key] = f"Invalid model ID: {model_id_to_load}"
            yield f"Error: {model_kind} model ID is not configured correctly ({model_id_to_load})."
            return

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id_to_load, trust_remote_code=True)
            # On ZeroGPU, device_map="auto" should leverage the @spaces.GPU context
            model = AutoModelForCausalLM.from_pretrained(
                model_id_to_load,
                torch_dtype=torch.bfloat16,  # Qwen models often prefer bfloat16
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            MODELS_LOADED[model_key] = model
            MODELS_LOADED[tokenizer_key] = tokenizer
            print(f"Successfully loaded and cached {model_kind} model and tokenizer.")
        except Exception as e:
            MODELS_LOADED[error_key] = str(e)
            print(f"ERROR loading {model_kind} model ({model_id_to_load}): {e}")
            yield f"Error loading {model_kind} model: {e}"  # Yield error to Gradio
            return  # Stop further execution for this call

    # Retrieve from cache
    model = MODELS_LOADED[model_key]
    tokenizer = MODELS_LOADED[tokenizer_key]

    if model is None or tokenizer is None:  # Should not happen if loading was successful
        yield f"Model or tokenizer for {model_kind} is unexpectedly None after loading attempt."
        return

    # Prepare conversation
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id != getattr(tokenizer, "unk_token_id", None) and im_end_id not in eos_tokens_ids:
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))  # Remove duplicates

    try:
        generated_token_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_tokens_ids,
        )
        new_tokens = generated_token_ids[0, inputs["input_ids"].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        full_response = ""
        for char_idx, char_val in enumerate(response_text):
            full_response += char_val
            # Yield more slowly or in chunks if char-by-char is too slow/frequent for Gradio
            if char_idx % 5 == 0 or char_idx == len(response_text) - 1:  # Yield every 5 chars or at the end
                time.sleep(0.001)  # Minimal sleep
                yield full_response
        if not response_text:  # Handle empty generation
            yield ""
    except Exception as e:
        print(f"Error during {model_kind} model generation: {e}")
        yield f"Error during generation: {e}"


# --- Gradio Event Handler Wrappers (these get decorated) ---
def create_gpu_handler(model_id, system_prompt, model_kind_str):
    # This function will be decorated by @spaces.GPU.
    # It calls the actual logic.
    def gpu_fn(message, chat_history):
        # The Chatbot components use type="messages", so convert the streamed
        # text from _load_and_infer into an updated message-dict history.
        chat_history = chat_history or []
        history_with_user = chat_history + [{"role": "user", "content": message}]
        for partial_response in _load_and_infer(message, chat_history, model_id, system_prompt, model_kind_str):
            yield history_with_user + [{"role": "assistant", "content": partial_response}]

    return gpu_fn


# Apply the decorator IF the `spaces` module was imported and has `GPU`
if spaces and hasattr(spaces, "GPU"):
    print("Applying @spaces.GPU decorator.")
    base_model_predict = spaces.GPU(create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base"))
    ft_model_predict = spaces.GPU(create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft"))
else:
    print("WARNING: @spaces.GPU decorator not applied. GPU acceleration on ZeroGPU might not work as expected.")
    # Fall back to non-decorated calls; this will likely lead to "No @spaces.GPU function detected"
    # or CUDA errors on ZeroGPU, which expects the decorator.
    base_model_predict = create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")
    ft_model_predict = create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft")


# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default()) as demo:  # Default theme; Soft can sometimes have issues
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
        Compare the fine-tuned CineGuide (`{FINETUNED_MODEL_ID}`) with the base {BASE_MODEL_ID}.

        **Note:** Models are loaded on first use within a GPU context and may take time.
        This Space attempts to use the ZeroGPU shared pool via `@spaces.GPU`.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")

    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your movie query...",
            container=False,
            scale=7,
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)

    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
            "Tell me about some good action movies from the 90s.",
            "Recommend a thought-provoking sci-fi film about AI.",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts",
    )

    # Event handling
    # `base_model_predict` and `ft_model_predict` are the (potentially) decorated functions.
    submit_button.click(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        api_name="base_predict",  # Useful for testing the API route
    )
    submit_button.click(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        api_name="ft_predict",
    )
    shared_input_textbox.submit(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
    )
    shared_input_textbox.submit(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
    )

    def clear_textbox_fn():
        return ""

    submit_button.click(clear_textbox_fn, [], [shared_input_textbox], queue=False)  # queue=False for instant clear
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox], queue=False)


if __name__ == "__main__":
    demo.queue()  # Enable queuing for multiple users
    # debug=True can sometimes interfere with production Spaces, but is fine for testing
    demo.launch(debug=True)
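
# ----------------------------------------------------------------------
# Optional usage sketch (not executed by the Space): once the app is running,
# the named routes above ("base_predict" / "ft_predict") could be called
# remotely with `gradio_client`. The Space id below is a placeholder
# assumption; replace it with the real "username/space-name".
#
# from gradio_client import Client
#
# client = Client("your-username/your-cineguide-space")  # hypothetical Space id
# history = client.predict(
#     "Recommend a thought-provoking sci-fi film about AI.",  # message
#     [],                                                     # existing chat history (messages format)
#     api_name="/ft_predict",
# )
# print(history[-1]["content"])  # final assistant reply from the fine-tuned model
# ----------------------------------------------------------------------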