import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os
# Attempt to import the spaces GPU decorator.
# This is a common pattern, but the exact import might vary or be injected.
try:
    import spaces  # This might make spaces.GPU available
except ImportError:
    spaces = None  # Define it as None if import fails, so we can check later
    print("WARNING: 'spaces' module not found. @spaces.GPU decorator might not be available or work as expected.")
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft" # Confirmed by you as correct
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Storage (placeholders) ---
# We will store model objects here after they are loaded within a GPU context.
# This addresses John6666's point about global variables not updating correctly
# if modified outside the main Gradio event flow or GPU context.
# We'll treat these more like a cache that's populated by GPU-context functions.
MODELS_LOADED = {
    "base_model": None,
    "base_tokenizer": None,
    "ft_model": None,
    "ft_tokenizer": None,
    "base_load_error": None,
    "ft_load_error": None,
}
# --- Core Model Loading and Inference Logic (to be wrapped by @spaces.GPU) ---
def _load_and_infer(message: str, chat_history: list, model_id_to_load: str, system_prompt: str, model_kind: str):
    """
    This function handles loading (if necessary) and inference.
    It's designed to be called by a function decorated with @spaces.GPU.
    """
    model_key = f"{model_kind}_model"
    tokenizer_key = f"{model_kind}_tokenizer"
    error_key = f"{model_kind}_load_error"

    # Check if model failed to load previously
    if MODELS_LOADED[error_key]:
        yield f"Previous attempt to load {model_kind} model ({model_id_to_load}) failed: {MODELS_LOADED[error_key]}"
        return

    # Load model and tokenizer if not already loaded
    if MODELS_LOADED[model_key] is None or MODELS_LOADED[tokenizer_key] is None:
        print(f"Attempting to load {model_kind} model: {model_id_to_load} (Type: {type(model_id_to_load)})")
        if not model_id_to_load or not isinstance(model_id_to_load, str):
            MODELS_LOADED[error_key] = f"Invalid model ID: {model_id_to_load}"
            yield f"Error: {model_kind} model ID is not configured correctly ({model_id_to_load})."
            return
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id_to_load, trust_remote_code=True)
            # On ZeroGPU, device_map="auto" should leverage the @spaces.GPU context
            model = AutoModelForCausalLM.from_pretrained(
                model_id_to_load,
                torch_dtype=torch.bfloat16,  # Qwen models often prefer bfloat16
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()
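            # Many causal LMs (including, as an assumption, this Qwen2.5 checkpoint) ship
            # without a dedicated pad token; reusing the EOS token as padding below is a
            # common default for generation-only workloads.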
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            MODELS_LOADED[model_key] = model
            MODELS_LOADED[tokenizer_key] = tokenizer
            print(f"Successfully loaded and cached {model_kind} model and tokenizer.")
        except Exception as e:
            MODELS_LOADED[error_key] = str(e)
            print(f"ERROR loading {model_kind} model ({model_id_to_load}): {e}")
            yield f"Error loading {model_kind} model: {e}"  # Yield error to Gradio
            return  # Stop further execution for this call
    # Retrieve from cache
    model = MODELS_LOADED[model_key]
    tokenizer = MODELS_LOADED[tokenizer_key]
    if model is None or tokenizer is None:  # Should not happen if loading was successful
        yield f"Model or tokenizer for {model_kind} is unexpectedly None after loading attempt."
        return

    # Prepare conversation
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
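    # apply_chat_template renders the conversation with the tokenizer's built-in chat
    # template; for Qwen2.5 this is the ChatML-style format delimited by
    # <|im_start|>/<|im_end|> tokens, ending with an open assistant turn because
    # add_generation_prompt=True.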
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id != getattr(tokenizer, 'unk_token_id', None) and im_end_id not in eos_tokens_ids:
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))  # Remove duplicates
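    # Qwen's chat format closes each turn with <|im_end|>, which may or may not equal
    # tokenizer.eos_token, so both ids are handed to generate() as stop tokens.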
    try:
        generated_token_ids = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
            repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
        )
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        full_response = ""
        for char_idx, char_val in enumerate(response_text):
            full_response += char_val
            # Yield in small chunks (every 5 chars, plus the final char) so Gradio
            # isn't flooded with one update per character.
            if char_idx % 5 == 0 or char_idx == len(response_text) - 1:
                time.sleep(0.001)  # Minimal sleep
                yield full_response
        if not response_text:  # Handle empty generation
            yield ""
    except Exception as e:
        print(f"Error during {model_kind} model generation: {e}")
        yield f"Error during generation: {e}"
# --- Gradio Event Handler Wrappers (these get decorated) ---
def create_gpu_handler(model_id, system_prompt, model_kind_str):
    # This function will be decorated by @spaces.GPU.
    # It calls the actual logic and converts the streamed text into the message-dict
    # history that gr.Chatbot(type="messages") expects as its output value.
    def gpu_fn(message, chat_history):
        history = list(chat_history) + [{"role": "user", "content": message}]
        for partial_response in _load_and_infer(message, chat_history, model_id, system_prompt, model_kind_str):
            yield history + [{"role": "assistant", "content": partial_response}]
    return gpu_fn
# Apply the decorator IF `spaces` module was imported and has `GPU`
if spaces and hasattr(spaces, "GPU"):
    print("Applying @spaces.GPU decorator.")
    base_model_predict = spaces.GPU(create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base"))
    ft_model_predict = spaces.GPU(create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft"))
else:
    print("WARNING: @spaces.GPU decorator not applied. GPU acceleration on ZeroGPU might not work as expected.")
    # Fall back to undecorated handlers; on ZeroGPU, which expects the decorator, this
    # will likely lead to "No @spaces.GPU function detected" or CUDA errors.
    base_model_predict = create_gpu_handler(BASE_MODEL_ID, SYSTEM_PROMPT_BASE, "base")
    ft_model_predict = create_gpu_handler(FINETUNED_MODEL_ID, SYSTEM_PROMPT_CINEGUIDE, "ft")
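# NOTE (assumption about ZeroGPU behavior): @spaces.GPU-decorated functions are only
# allocated a GPU for the duration of each call, which is why model loading happens
# lazily inside the handler instead of at import time.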
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default()) as demo:  # Changed to Default theme, Soft can sometimes have issues
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide (`{FINETUNED_MODEL_ID}`) with the base {BASE_MODEL_ID}.
**Note:** Models are loaded on first use within a GPU context and may take time.
This Space attempts to use the ZeroGPU shared pool via `@spaces.GPU`.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")
    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
            "Tell me about some good action movies from the 90s.",
            "Recommend a thought-provoking sci-fi film about AI.",
        ],
        inputs=[shared_input_textbox], label="Example Prompts"
    )
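    # gr.Examples here only populates the textbox (no fn is attached), so the user
    # still presses Send to run both models on the chosen example prompt.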
    # Event handling
    # The `base_model_predict` and `ft_model_predict` are now the (potentially) decorated functions.
    submit_button.click(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        api_name="base_predict"  # Good for testing API route
    )
    submit_button.click(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        api_name="ft_predict"
    )
    shared_input_textbox.submit(
        base_model_predict,
        [shared_input_textbox, chatbot_base],
        [chatbot_base]
    )
    shared_input_textbox.submit(
        ft_model_predict,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft]
    )

    def clear_textbox_fn():
        return ""

    submit_button.click(clear_textbox_fn, [], [shared_input_textbox], queue=False)  # queue=False for instant clear
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox], queue=False)
if __name__ == "__main__":
    demo.queue()  # Enable queuing for multiple users
    # debug=True can sometimes interfere with production Spaces, but fine for testing
    demo.launch(debug=True)