import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os
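# Note: TextStreamer and os are imported but not used in this version; streaming is
# simulated by replaying the decoded response (see generate_chat_response below).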

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# NOW, this points to your model on the Hugging Face Hub
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

# System prompts (same as before)
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Model Loading ---
_models_cache = {}

def get_model_and_tokenizer(model_id_or_path, is_local_path=False):  # Added is_local_path for flexibility
    if model_id_or_path in _models_cache:
        return _models_cache[model_id_or_path]

    print(f"Loading model: {model_id_or_path}")
    # For models from the Hub, trust_remote_code is often needed for custom architectures like Qwen.
    # For local paths, it may also be needed if they were saved with trust_remote_code=True.
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        # attn_implementation="flash_attention_2"  # Optional
    )
    model.eval()

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Ensure pad_token_id is also set if pad_token is set
    if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    _models_cache[model_id_or_path] = (model, tokenizer)
    print(f"Finished loading: {model_id_or_path}")
    return model, tokenizer
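
# Memory note: two 7B models in bfloat16 need roughly 28 GB of GPU memory combined, plus
# activation overhead. If the Space's hardware is smaller, one option (not used here) is a
# 4-bit quantized load via transformers' BitsAndBytesConfig, which requires bitsandbytes:
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#   model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config, device_map="auto")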
print("Pre-loading models...") | |
model_base, tokenizer_base = None, None | |
model_ft, tokenizer_ft = None, None | |
try: | |
model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID) | |
print("Base model loaded.") | |
except Exception as e: | |
print(f"Error loading base model ({BASE_MODEL_ID}): {e}") | |
try: | |
model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_ID) | |
print("Fine-tuned model loaded.") | |
except Exception as e: | |
print(f"Error loading fine-tuned model ({FINETUNED_MODEL_ID}): {e}") | |
print("Model pre-loading complete.") | |

# --- Inference Function (generate_chat_response) ---
# This function remains largely the same as in the previous app.py.
# Make sure it uses `model_base, tokenizer_base` and `model_ft, tokenizer_ft` correctly.
def generate_chat_response(message: str, chat_history: list, model_type: str):
    # ... (Keep the exact same generate_chat_response function from the previous app.py)
    if model_type == "base":
        if model_base is None or tokenizer_base is None:
            yield f"Base model ({BASE_MODEL_ID}) is not available."
            return
        model, tokenizer = model_base, tokenizer_base
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type == "finetuned":
        if model_ft is None or tokenizer_ft is None:
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."
            return
        model, tokenizer = model_ft, tokenizer_ft
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        if user_msg:  # Ensure user_msg is not None
            conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Ensure assistant_msg is not None
            conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    full_response = ""
    # Make sure eos_token_id is a list if multiple EOS tokens are possible
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != tokenizer.unk_token_id:  # Only append <|im_end|> if it is actually in the vocab
        eos_tokens_ids.append(im_end_id)
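    # Qwen's chat template closes each assistant turn with <|im_end|>, so treating it as an
    # extra EOS token stops generation at the end of the turn instead of running on.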

    generated_token_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,  # Use pad_token_id
        eos_token_id=eos_tokens_ids
    )
    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    response_text = response_text.replace("<|im_end|>", "").strip()
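
    # The response is fully generated above; the loop below just replays it character by
    # character so the Gradio chatbots appear to stream. True token-level streaming would
    # use transformers' TextIteratorStreamer in a background thread (not done here).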
    for char in response_text:
        full_response += char
        time.sleep(0.005)
        yield full_response


def respond_base(message, chat_history):
    yield from generate_chat_response(message, chat_history, "base")


def respond_finetuned(message, chat_history):
    yield from generate_chat_response(message, chat_history, "finetuned")

# --- Gradio UI (with gr.Blocks as demo:) ---
# This part remains largely the same as the previous app.py.
# Ensure the Markdown and labels correctly reference the models being loaded.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
with the base {BASE_MODEL_ID} model.
Type your movie-related query below and see how each model responds!
"""
    )

    # ... (Rest of the UI definition: Rows, Columns, Chatbots, Textbox, Button, Examples)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False)
            if model_base is None:
                gr.Markdown(f"⚠️ Base model ({BASE_MODEL_ID}) could not be loaded.")
        with gr.Column(scale=1):
            gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False)
            if model_ft is None:
                gr.Markdown(f"⚠️ Fine-tuned model ({FINETUNED_MODEL_ID}) could not be loaded.")

    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your movie query here and press Enter...",
            container=False,
            scale=7,
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)

    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick. Think more British comedy style.",
            "I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.",
            "I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.",
            "I'm going through a tough breakup and need something uplifting but not cheesy romantic.",
            "I loved Parasite and want to explore more international cinema. Where should I start?",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts (click to use)"
    )

    def base_model_predict(user_message, chat_history):
        if model_base is None:  # Add this check
            chat_history.append((user_message, f"Base model ({BASE_MODEL_ID}) is not available."))
            yield chat_history
            return
        chat_history.append((user_message, ""))
        for response_chunk in respond_base(user_message, chat_history[:-1]):
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history

    def ft_model_predict(user_message, chat_history):
        if model_ft is None:  # Add this check
            chat_history.append((user_message, f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."))
            yield chat_history
            return
        chat_history.append((user_message, ""))
        for response_chunk in respond_finetuned(user_message, chat_history[:-1]):
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history
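
    # Both handlers use the tuple-style history that gr.Chatbot produces by default.
    # If the Chatbot components were created with type="messages", the history would be a
    # list of {"role": ..., "content": ...} dicts and these functions would need adapting.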

    # Event handlers
    actions = []
    if model_base is not None:
        actions.append(
            shared_input_textbox.submit(
                base_model_predict,
                [shared_input_textbox, chatbot_base],
                [chatbot_base],
                queue=True
            )
        )
        actions.append(
            submit_button.click(
                base_model_predict,
                [shared_input_textbox, chatbot_base],
                [chatbot_base],
                queue=True
            )
        )
    if model_ft is not None:
        actions.append(
            shared_input_textbox.submit(
                ft_model_predict,
                [shared_input_textbox, chatbot_ft],
                [chatbot_ft],
                queue=True
            )
        )
        actions.append(
            submit_button.click(
                ft_model_predict,
                [shared_input_textbox, chatbot_ft],
                [chatbot_ft],
                queue=True
            )
        )

    # Clear the textbox after all submits are queued. This is slightly simplified;
    # for a more robust clear, you might need to chain these events or use gr.Group.
    def clear_textbox_fn():
        return ""

    if actions:  # If any model is active
        shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
        submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
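
    # A more deterministic alternative (not used here) is to chain the clear onto a specific
    # event with .then(), which Gradio event listeners return, e.g.:
    #   shared_input_textbox.submit(base_model_predict, ...).then(clear_textbox_fn, None, shared_input_textbox)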

# --- Launch the App ---
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)  # share=True for a public link if running locally