# Captured from the Hugging Face Space file view: app.py by serhany,
# last commit 18449fc "Update app.py" (verified), file size 12.5 kB.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os # Keep os, it might be useful
# --- Configuration ---
# Hugging Face Hub repository IDs of the two models shown side by side.
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"  # NOTE(review): confirm this repo ID exists on the Hub

# System prompt prepended to every conversation with the fine-tuned CineGuide model.
SYSTEM_PROMPT_CINEGUIDE = (
    "You are CineGuide, a knowledgeable and friendly movie recommendation assistant. "
    "Your goal is to:\n"
    "1. Provide personalized movie recommendations based on user preferences\n"
    "2. Give brief, compelling rationales for why you recommend each movie\n"
    "3. Ask thoughtful follow-up questions to better understand user tastes\n"
    "4. Maintain an enthusiastic but not overwhelming tone about cinema\n"
    "When recommending movies, always explain WHY the movie fits their preferences."
)

# Neutral system prompt prepended to every conversation with the base model.
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Cache (models will be loaded on first use) ---
_models_cache = {
"base": None,
"finetuned": None,
"tokenizer_base": None,
"tokenizer_ft": None,
}
# --- Model Loading Function (to be called inside decorated functions) ---
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Load a causal-LM model and its tokenizer, memoised in ``_models_cache``.

    Args:
        model_identifier: Hub repo ID (or local path) passed to ``from_pretrained``.
        model_key: ``_models_cache`` slot that holds the model.
        tokenizer_key: ``_models_cache`` slot that holds the tokenizer.

    Returns:
        ``(model, tokenizer)``: the model switched to eval mode, and the tokenizer
        with its pad token falling back to the EOS token when it had none.

    Raises:
        RuntimeError: if an earlier load for these slots failed (the cache holds
            the ``"error"`` marker); the load is not retried.
        Exception: whatever ``from_pretrained`` raises on a failed first load;
            both slots are marked ``"error"`` before re-raising.
    """
    cached_model = _models_cache[model_key]
    cached_tokenizer = _models_cache[tokenizer_key]
    # A failed load leaves the non-None marker "error" in the cache. Without this
    # check the fast path below would return the marker strings as a model/tokenizer.
    if cached_model == "error" or cached_tokenizer == "error":
        raise RuntimeError(
            f"{model_key} model ({model_identifier}) failed to load previously; not retrying."
        )
    if cached_model is not None and cached_tokenizer is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return cached_model, cached_tokenizer

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,  # Or torch.float16 if better for the available GPU
            device_map="auto",  # Place weights on whatever accelerator is available
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        # Mark both slots so later calls fail fast instead of re-attempting the load.
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise  # Re-raise so the error shows up in the Gradio UI / logs

    model.eval()  # inference only
    # generate() is given tokenizer.pad_token_id; fall back to EOS when no pad token exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    _models_cache[model_key] = model
    _models_cache[tokenizer_key] = tokenizer
    print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
    return model, tokenizer
# --- Inference Function (modified to ensure models are loaded) ---
def _build_conversation(system_prompt, chat_history, message, count_tokens, max_prompt_tokens):
    """Assemble chat-template messages, dropping the oldest history until the prompt fits.

    The optional system prompt and the new user ``message`` are always kept. Turns
    are removed from the front of ``chat_history`` (the kept history is made to begin
    on a user turn) until ``count_tokens(messages) <= max_prompt_tokens`` or no
    history is left.
    """
    head = [{"role": "system", "content": system_prompt}] if system_prompt else []
    tail = [{"role": "user", "content": message}]
    history = list(chat_history)
    start = 0
    while start < len(history) and count_tokens(head + history[start:] + tail) > max_prompt_tokens:
        start += 1
        # Advance to the next user turn so the kept history does not begin mid-exchange.
        while start < len(history) and history[start]["role"] != "user":
            start += 1
    return head + history[start:] + tail


def _stop_token_ids(tokenizer):
    """Return the ids that end generation: EOS, plus ``<|im_end|>`` when it is in the vocab."""
    candidates = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # An id equal to the unk id means "<|im_end|>" is not a real vocabulary token.
    if im_end_id != getattr(tokenizer, "unk_token_id", None):
        candidates.append(im_end_id)
    # De-duplicate (EOS may already be <|im_end|>) in order, and drop missing (None) ids.
    return [token_id for token_id in dict.fromkeys(candidates) if token_id is not None]


def generate_chat_response(message: str, chat_history: list, model_type_to_load: str, max_prompt_tokens: int = 1800):
    """Generate a reply with the base or fine-tuned model, streaming cumulative text.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (assumed to be Gradio ``type="messages"`` history).
        model_type_to_load: ``"base"`` or ``"finetuned"``.
        max_prompt_tokens: Token budget for the prompt. The oldest history turns are
            dropped first; if the prompt is still too long, it is truncated from the left.

    Yields:
        The assistant reply accumulated so far (each value supersedes the previous),
        or a single error string when the model type or model is unavailable.

    Raises:
        Whatever ``load_model_and_tokenizer`` or ``model.generate`` raises.
    """
    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        # Guard against a misconfigured repo ID before attempting any download.
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID} (Type: {type(FINETUNED_MODEL_ID)})")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return
    if model is None or tokenizer is None:  # Defensive: a failed load raises above
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    def count_tokens(messages):
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return len(tokenizer(text)["input_ids"])

    conversation = _build_conversation(system_prompt, chat_history, message, count_tokens, max_prompt_tokens)
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep only the LAST max_prompt_tokens tokens. Right-side truncation would cut off the
    # newest user message and the assistant generation header at the end of the prompt.
    inputs = {name: tensor[:, -max_prompt_tokens:].to(model.device) for name, tensor in encoded.items()}

    # Imported here so the validation paths above run without these dependencies.
    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
    }
    stop_ids = _stop_token_ids(tokenizer)
    if stop_ids:  # Omitted when empty so generate() keeps its configured default
        generate_kwargs["eos_token_id"] = stop_ids

    failures = []

    def _run_generation():
        # Runs in a background thread; decoded text arrives through `streamer`.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # Re-raised in the consumer below
            failures.append(exc)
            streamer.end()  # Unblock the consuming loop below

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()
    response = ""
    for piece in streamer:
        response += piece
        yield response.replace("<|im_end|>", "").strip()
    worker.join()
    if failures:
        raise failures[0]
# --- Gradio UI Event Handlers ---
# NOTE(review): on ZeroGPU hardware, Hugging Face documents `import spaces` and
# decorating the GPU-using event handlers with `@spaces.GPU`. That package is not
# imported in this file, so the handlers below are left undecorated; add the
# decorator once `spaces` is available in the Space's environment.


def _stream_into_history(user_message, chat_history, model_type, handler_name, error_prefix):
    """Stream one model reply into a copy of a ``type="messages"`` chat history.

    Yields complete history lists, which is the value a ``gr.Chatbot(type="messages")``
    output expects. The first list has only the user turn appended. Later lists carry
    the assistant turn as ``generate_chat_response`` streams text. Errors are shown as
    the assistant turn rather than propagated into Gradio. An empty or whitespace-only
    message leaves the history unchanged.
    """
    previous = list(chat_history or [])
    if not (user_message or "").strip():
        yield previous
        return
    history = previous + [{"role": "user", "content": user_message}]
    yield history  # Show the user's message immediately, before generation starts
    try:
        for partial in generate_chat_response(user_message, previous, model_type):
            yield history + [{"role": "assistant", "content": partial}]
    except Exception as e:
        print(f"Error in {handler_name}: {e}")
        yield history + [{"role": "assistant", "content": f"{error_prefix}: {e}"}]


def base_model_predict_decorated(user_message, chat_history):
    """Gradio handler: stream the base model's reply into the base-model chat history."""
    yield from _stream_into_history(
        user_message,
        chat_history,
        "base",
        "base_model_predict_decorated",
        "Error generating base model response",
    )


def ft_model_predict_decorated(user_message, chat_history):
    """Gradio handler: stream the fine-tuned model's reply into the CineGuide chat history."""
    yield from _stream_into_history(
        user_message,
        chat_history,
        "finetuned",
        "ft_model_predict_decorated",
        "Error generating fine-tuned response",
    )
# --- Gradio UI Definition ---
# Two chat panes side by side, fed by one shared textbox: every message goes to
# both models, and each model writes only to its own chatbot.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
        Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
        with the base {BASE_MODEL_ID} model.
        Type your movie-related query below and see how each model responds!
        **Note:** Models are loaded on first use and may take some time. Using shared GPU pool.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")
    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts",
    )

    # Clicking "Send" and pressing Enter in the textbox behave identically.
    send_triggers = (submit_button.click, shared_input_textbox.submit)
    for trigger in send_triggers:
        trigger(base_model_predict_decorated, [shared_input_textbox, chatbot_base], [chatbot_base])
        trigger(ft_model_predict_decorated, [shared_input_textbox, chatbot_ft], [chatbot_ft])

    def clear_textbox_fn():
        """Return an empty string to reset the shared input box."""
        return ""

    for trigger in send_triggers:
        trigger(clear_textbox_fn, [], [shared_input_textbox])
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the server. With debug=True,
    # launch() blocks the main thread and reports errors to the console.
    demo.queue().launch(debug=True)