import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

# Attempt to import the `spaces` GPU decorator.
# This is the usual pattern on Hugging Face Spaces, but the module is only
# available inside the Spaces runtime, so guard the import.
try:
    import spaces  # Makes spaces.GPU available
except ImportError:
    spaces = None  # Define it as None if the import fails, so we can check later
    print("WARNING: 'spaces' module not found. The @spaces.GPU decorator might not be available or work as expected.")
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"  # Confirmed as the correct fine-tuned repo ID

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Storage (cache) ---
# We store model objects here after they are loaded within a GPU context.
# This addresses John6666's point about global variables not updating correctly
# if modified outside the main Gradio event flow or GPU context.
# Treat these entries as a cache that is populated by the GPU-context functions.
MODELS_LOADED = {
    "base_model": None,
    "base_tokenizer": None,
    "ft_model": None,
    "ft_tokenizer": None,
    "base_load_error": None,
    "ft_load_error": None,
}
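# The keys above follow the f"{model_kind}_model" / f"{model_kind}_tokenizer" /
# f"{model_kind}_load_error" naming that _load_and_infer() uses, where
# model_kind is either "base" or "ft".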
# --- Core Model Loading and Inference Logic (to be wrapped by @spaces.GPU) ---
def _load_and_infer(message: str, chat_history: list, model_id_to_load: str, system_prompt: str, model_kind: str):
    """
    Handle loading (if necessary) and inference, yielding partial response text.
    Designed to be called by a function decorated with @spaces.GPU.
    """
    model_key = f"{model_kind}_model"
    tokenizer_key = f"{model_kind}_tokenizer"
    error_key = f"{model_kind}_load_error"

    # If a previous load attempt failed, report the cached error instead of retrying.
    if MODELS_LOADED[error_key]:
        yield f"Previous attempt to load {model_kind} model ({model_id_to_load}) failed: {MODELS_LOADED[error_key]}"
        return

    # Load model and tokenizer if they are not already cached.
    if MODELS_LOADED[model_key] is None or MODELS_LOADED[tokenizer_key] is None:
        print(f"Attempting to load {model_kind} model: {model_id_to_load} (Type: {type(model_id_to_load)})")
        if not model_id_to_load or not isinstance(model_id_to_load, str):
            MODELS_LOADED[error_key] = f"Invalid model ID: {model_id_to_load}"
            yield f"Error: {model_kind} model ID is not configured correctly ({model_id_to_load})."
            return
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id_to_load, trust_remote_code=True)
            # On ZeroGPU, device_map="auto" should leverage the @spaces.GPU context.
            model = AutoModelForCausalLM.from_pretrained(
                model_id_to_load,
                torch_dtype=torch.bfloat16,  # Qwen models are typically run in bfloat16
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()

            # Qwen tokenizers may not define a pad token; fall back to the EOS token.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            if getattr(tokenizer, "pad_token_id", None) is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            MODELS_LOADED[model_key] = model
            MODELS_LOADED[tokenizer_key] = tokenizer
            print(f"Successfully loaded and cached {model_kind} model and tokenizer.")
        except Exception as e:
            MODELS_LOADED[error_key] = str(e)
            print(f"ERROR loading {model_kind} model ({model_id_to_load}): {e}")
            yield f"Error loading {model_kind} model: {e}"  # Surface the error in the Gradio UI
            return  # Stop further execution for this call
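    # Rough memory note (an assumption, not measured here): a 7B-parameter model in
    # bfloat16 needs about 7e9 * 2 bytes ≈ 14 GB of GPU memory for the weights alone,
    # so caching both the base and fine-tuned models leaves comparatively little
    # headroom for KV-cache on smaller GPUs.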
    # Retrieve from cache.
    model = MODELS_LOADED[model_key]
    tokenizer = MODELS_LOADED[tokenizer_key]
    if model is None or tokenizer is None:  # Should not happen if loading succeeded
        yield f"Model or tokenizer for {model_kind} is unexpectedly None after loading attempt."
        return

    # Build the conversation in the chat-template format.
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    # Stop generation on EOS or on Qwen's end-of-turn token <|im_end|>.
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None) and im_end_id not in eos_tokens_ids:
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))  # Remove duplicates
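    # The generate() call below produces the full completion in one shot and the
    # loop afterwards only simulates streaming. A sketch of true token-level
    # streaming (an alternative approach, not part of the original flow) would
    # run generation in a background thread with transformers' TextIteratorStreamer:
    #
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer)).start()
    #   partial = ""
    #   for new_text in streamer:
    #       partial += new_text
    #       yield partial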
    try:
        generated_token_ids = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
            repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
        )
        # Keep only the newly generated tokens (everything after the prompt).
        new_tokens = generated_token_ids[0, inputs["input_ids"].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Yield the decoded text progressively so the UI feels like it is streaming.
        full_response = ""
        for char_idx, char_val in enumerate(response_text):
            full_response += char_val
            if char_idx % 5 == 0 or char_idx == len(response_text) - 1:  # Yield every 5 chars and at the end
                time.sleep(0.001)  # Minimal sleep so Gradio can flush updates
                yield full_response
        if not response_text:  # Handle empty generation
            yield ""
    except Exception as e:
        print(f"Error during {model_kind} model generation: {e}")
        yield f"Error during generation: {e}"
# --- Gradio Event Handler Wrappers (these get decorated) ---
def create_gpu_handler(model_id, system_prompt, model_kind_str):
    # Returns the (message, chat_history) handler that @spaces.GPU will wrap.
    # It calls the actual logic and converts the partial text yields from
    # _load_and_infer into the "messages" format that gr.Chatbot(type="messages")
    # expects as its output value.
    def gpu_fn(message, chat_history):
        chat_history = chat_history or []
        updated_history = chat_history + [{"role": "user", "content": message}]
        for partial_text in _load_and_infer(message, chat_history, model_id, system_prompt, model_kind_str):
            yield updated_history + [{"role": "assistant", "content": partial_text}]
    return gpu_fn
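# Hypothetical local smoke test (not wired into the app; run manually if needed):
#   handler = create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")
#   for history in handler("Recommend a thought-provoking sci-fi film.", []):
#       pass
#   print(history[-1]["content"])  # final assistant message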
# Apply the decorator only if the `spaces` module was imported and exposes `GPU`.
if spaces and hasattr(spaces, "GPU"):
    print("Applying @spaces.GPU decorator.")
    base_model_predict = spaces.GPU(create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base"))
    ft_model_predict = spaces.GPU(create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft"))
else:
    print("WARNING: @spaces.GPU decorator not applied. GPU acceleration on ZeroGPU might not work as expected.")
    # Fall back to non-decorated handlers; on ZeroGPU this will likely lead to a
    # "No @spaces.GPU function detected" error or CUDA errors, since the runtime
    # expects the decorator.
    base_model_predict = create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")
    ft_model_predict = create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft")
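# Note (assumption based on the ZeroGPU documentation): if a single generation
# regularly exceeds the default GPU allocation window, spaces.GPU also accepts a
# duration argument, e.g. spaces.GPU(duration=120)(create_gpu_handler(...)).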
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default()) as demo:  # Default theme; Soft can sometimes cause issues
    gr.Markdown(
        f"""# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide (`{FINETUNED_MODEL_ID}`) with the base {BASE_MODEL_ID}.

**Note:** Models are loaded on first use within a GPU context and may take some time.
This Space attempts to use the ZeroGPU shared pool via `@spaces.GPU`."""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")

    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
            "Tell me about some good action movies from the 90s.",
            "Recommend a thought-provoking sci-fi film about AI.",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts",
    )
    # Event handling.
    # `base_model_predict` and `ft_model_predict` are the (potentially) decorated
    # generators; each click/submit streams an updated message list into its chatbot.
    submit_button.click(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        api_name="base_predict",  # Useful for testing the API route
    )
    submit_button.click(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        api_name="ft_predict",
    )
    shared_input_textbox.submit(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
    )
    shared_input_textbox.submit(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
    )

    def clear_textbox_fn():
        return ""

    submit_button.click(clear_textbox_fn, [], [shared_input_textbox], queue=False)  # queue=False for an instant clear
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox], queue=False)
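    # Behavioral note (assumption about Gradio's event handling): input values are
    # snapshotted when the click/submit event fires, so clearing the textbox via the
    # unqueued handlers above does not affect the in-flight predictions.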
if __name__ == "__main__":
    demo.queue()  # Enable queuing for multiple users
    # debug=True can sometimes interfere with production Spaces, but it is fine for testing.
    demo.launch(debug=True)