import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
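# `spaces` supplies the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces
# to request a GPU for the duration of a decorated call.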
import spaces
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Cache ---
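# Models and tokenizers are loaded lazily on the first request and reused afterwards;
# the string "error" is stored as a sentinel when a load attempt fails.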
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise
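# Shared generation routine for both models; implemented as a generator so the UI
# can show the reply incrementally.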
def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
    """Generate a response using the specified model type."""
    model, tokenizer = None, None
    system_prompt = ""
    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting to load it."
        return
    # Prepare the conversation in chat-template format
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    # Add prior chat history, then the new user message
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})

    # Build the prompt and tokenize (truncated to keep the input within the context window)
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    # Prepare EOS tokens: Qwen chat models end turns with <|im_end|>, so treat it as an
    # additional end-of-sequence token (unless the tokenizer maps it to the unknown token)
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id != getattr(tokenizer, 'unk_token_id', None):
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))
    # Generate
    with torch.no_grad():
        generated_token_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_tokens_ids
        )
    # Keep only the newly generated tokens (drop the prompt portion) and clean up the text
    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
    # Simulate streaming by yielding the accumulated response one character at a time
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)
        yield full_response
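# @spaces.GPU: on ZeroGPU Spaces, a GPU is attached only while the decorated function runs.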
@spaces.GPU
def base_model_predict(user_message, chat_history):
    """Predict using base model - decorated with @spaces.GPU."""
    try:
        bot_response_stream = generate_chat_response(user_message, chat_history, "base")
        for chunk in bot_response_stream:
            yield chunk
    except Exception as e:
        print(f"Error in base_model_predict: {e}")
        yield f"Error generating base model response: {str(e)}"
@spaces.GPU
def ft_model_predict(user_message, chat_history):
    """Predict using fine-tuned model - decorated with @spaces.GPU."""
    try:
        bot_response_stream = generate_chat_response(user_message, chat_history, "finetuned")
        for chunk in bot_response_stream:
            yield chunk
    except Exception as e:
        print(f"Error in ft_model_predict: {e}")
        yield f"Error generating fine-tuned response: {str(e)}"
def format_chat_history(history, message):
    """Format the chat history for the models."""
    formatted_history = []
    for chat in history:
        if isinstance(chat, dict) and 'role' in chat:
            formatted_history.append(chat)
        elif isinstance(chat, list) and len(chat) == 2:
            formatted_history.extend([
                {"role": "user", "content": chat[0]},
                {"role": "assistant", "content": chat[1]}
            ])
    return formatted_history
def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    formatted_history = format_chat_history(history, message)
    response_gen = base_model_predict(message, formatted_history)
    for response in response_gen:
        yield response

def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    formatted_history = format_chat_history(history, message)
    response_gen = ft_model_predict(message, formatted_history)
    for response in response_gen:
        yield response
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base Model Comparison

Compare the fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.

**Base Model:** `{BASE_MODEL_ID}`

**Fine-tuned Model:** `{FINETUNED_MODEL_ID}`

Type your movie-related query below and see how each model responds!

⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                theme="soft",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False,
                retry_btn=None,
                undo_btn="⤴️ Undo",
                clear_btn="🗑️ Clear"
            )
        with gr.Column(scale=1):
            gr.Markdown(f"## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown(f"*Specialized for movie recommendations*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                theme="soft",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False,
                retry_btn=None,
                undo_btn="⤴️ Undo",
                clear_btn="🗑️ Clear"
            )
if __name__ == "__main__":
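    # Queue incoming requests (at most 20 waiting) before launching the app.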
    demo.queue(max_size=20)
    demo.launch()