import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import os
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# Fine-tuned model hosted on the Hugging Face Hub
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
# System prompts
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Model Loading ---
_models_cache = {}
def get_model_and_tokenizer(model_id_or_path, is_local_path=False):  # is_local_path reserved for local checkpoints; currently unused
if model_id_or_path in _models_cache:
return _models_cache[model_id_or_path]
print(f"Loading model: {model_id_or_path}")
    # trust_remote_code lets a Hub repo (or a local checkpoint saved with custom code)
    # supply its own modeling code; recent Qwen2.5 checkpoints are also supported
    # natively by transformers, so the flag is mainly a safety net here.
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id_or_path,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
# attn_implementation="flash_attention_2" # Optional
)
model.eval()
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Ensure pad_token_id is also set if pad_token is set
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
_models_cache[model_id_or_path] = (model, tokenizer)
print(f"Finished loading: {model_id_or_path}")
return model, tokenizer
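
# Hedged alternative (not used above): on smaller GPUs the same checkpoints can be
# loaded in 4-bit with bitsandbytes. This is only a sketch and assumes the
# `bitsandbytes` package is installed; the rest of the app keeps calling
# get_model_and_tokenizer() as before.
def get_model_and_tokenizer_4bit(model_id_or_path):
    from transformers import BitsAndBytesConfig  # bitsandbytes must be installed at load time
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer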
print("Pre-loading models...")
model_base, tokenizer_base = None, None
model_ft, tokenizer_ft = None, None
try:
model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID)
print("Base model loaded.")
except Exception as e:
print(f"Error loading base model ({BASE_MODEL_ID}): {e}")
try:
model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_ID)
print("Fine-tuned model loaded.")
except Exception as e:
print(f"Error loading fine-tuned model ({FINETUNED_MODEL_ID}): {e}")
print("Model pre-loading complete.")
# --- Inference Function ---
# Streams a response from the selected model (base or fine-tuned) for the given
# message and chat history.
def generate_chat_response(message: str, chat_history: list, model_type: str):
    # Pick the model/tokenizer pair and system prompt for the requested variant.
if model_type == "base":
if model_base is None or tokenizer_base is None:
yield f"Base model ({BASE_MODEL_ID}) is not available."
return
model, tokenizer = model_base, tokenizer_base
system_prompt = SYSTEM_PROMPT_BASE
elif model_type == "finetuned":
if model_ft is None or tokenizer_ft is None:
yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."
return
model, tokenizer = model_ft, tokenizer_ft
system_prompt = SYSTEM_PROMPT_CINEGUIDE
else:
yield "Invalid model type."
return
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user_msg, assistant_msg in chat_history:
if user_msg: # Ensure user_msg is not None
conversation.append({"role": "user", "content": user_msg})
if assistant_msg: # Ensure assistant_msg is not None
conversation.append({"role": "assistant", "content": assistant_msg})
conversation.append({"role": "user", "content": message})
prompt = tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
full_response = ""
    # Qwen chat models terminate turns with <|im_end|>; include it alongside the
    # tokenizer's eos token, skipping ids that are missing or already present.
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id not in eos_tokens_ids:
        eos_tokens_ids.append(im_end_id)
generated_token_ids = model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
eos_token_id=eos_tokens_ids
)
new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
response_text = response_text.replace("<|im_end|>", "").strip()
for char in response_text:
full_response += char
time.sleep(0.005)
yield full_response
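
# Hedged sketch (not wired into the UI): generate_chat_response() above replays the
# finished text character by character. True token-level streaming could instead use
# transformers' TextIteratorStreamer with generation running in a background thread.
# Sampling settings mirror those used above; `inputs` is the tokenized prompt dict.
def stream_tokens(model, tokenizer, inputs, max_new_tokens=512):
    from threading import Thread
    from transformers import TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:  # yields decoded text as tokens are produced
        partial += new_text
        yield partial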
def respond_base(message, chat_history):
yield from generate_chat_response(message, chat_history, "base")
def respond_finetuned(message, chat_history):
yield from generate_chat_response(message, chat_history, "finetuned")
# --- Gradio UI ---
# Side-by-side chat comparing the base model with the fine-tuned CineGuide model.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
with the base {BASE_MODEL_ID} model.
Type your movie-related query below and see how each model responds!
"""
)
    # Layout: two chat columns, a shared input row, and clickable example prompts.
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False)
if model_base is None:
gr.Markdown(f"⚠️ Base model ({BASE_MODEL_ID}) could not be loaded.")
with gr.Column(scale=1):
gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False)
if model_ft is None:
gr.Markdown(f"⚠️ Fine-tuned model ({FINETUNED_MODEL_ID}) could not be loaded.")
with gr.Row():
shared_input_textbox = gr.Textbox(
show_label=False,
placeholder="Enter your movie query here and press Enter...",
container=False,
scale=7,
)
submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
gr.Examples(
examples=[
"Hi! I'm looking for something funny to watch tonight.",
"I love dry, witty humor more than slapstick. Think more British comedy style.",
"I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.",
"I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.",
"I'm going through a tough breakup and need something uplifting but not cheesy romantic.",
"I loved Parasite and want to explore more international cinema. Where should I start?",
],
inputs=[shared_input_textbox],
label="Example Prompts (click to use)"
)
def base_model_predict(user_message, chat_history):
        if model_base is None:  # guard: base model failed to load
chat_history.append((user_message, f"Base model ({BASE_MODEL_ID}) is not available."))
yield chat_history
return
chat_history.append((user_message, ""))
for response_chunk in respond_base(user_message, chat_history[:-1]):
chat_history[-1] = (user_message, response_chunk)
yield chat_history
def ft_model_predict(user_message, chat_history):
        if model_ft is None:  # guard: fine-tuned model failed to load
chat_history.append((user_message, f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."))
yield chat_history
return
chat_history.append((user_message, ""))
for response_chunk in respond_finetuned(user_message, chat_history[:-1]):
chat_history[-1] = (user_message, response_chunk)
yield chat_history
# Event handlers
actions = []
if model_base is not None:
actions.append(
shared_input_textbox.submit(
base_model_predict,
[shared_input_textbox, chatbot_base],
[chatbot_base],
queue=True
)
)
actions.append(
submit_button.click(
base_model_predict,
[shared_input_textbox, chatbot_base],
[chatbot_base],
queue=True
)
)
if model_ft is not None:
actions.append(
shared_input_textbox.submit(
ft_model_predict,
[shared_input_textbox, chatbot_ft],
[chatbot_ft],
queue=True
)
)
actions.append(
submit_button.click(
ft_model_predict,
[shared_input_textbox, chatbot_ft],
[chatbot_ft],
queue=True
)
)
    # Clear the textbox after the submit/click handlers are queued. This is slightly
    # simplified; a more robust approach chains the clear onto each handler with
    # .then() (see the commented sketch below) or groups the events.
def clear_textbox_fn():
return ""
if actions: # If any model is active
shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
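    # Hedged alternative (commented out so it does not double-wire the handlers):
    # Gradio event listeners can be chained with .then(), so the textbox is cleared
    # only after a model handler has finished, e.g.:
    #
    #   if model_ft is not None:
    #       submit_button.click(
    #           ft_model_predict, [shared_input_textbox, chatbot_ft], [chatbot_ft], queue=True
    #       ).then(clear_textbox_fn, [], [shared_input_textbox])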
# --- Launch the App ---
if __name__ == "__main__":
demo.queue()
demo.launch(debug=True) # share=True for public link if running locally