import os
from threading import Thread

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
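# Assumption about the Space environment: device_map="auto" requires the `accelerate`
# package, so a typical requirements.txt for this app would pin gradio, torch,
# transformers, and accelerate (exact versions are an assumption and should follow
# the Qwen2.5 model card).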
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# Path to your merged fine-tuned model within the Hugging Face Space
# If 'cineguide-merged' is at the root of your Space repo:
FINETUNED_MODEL_PATH = "cineguide-merged"
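# Pick a dtype the Space hardware can actually run. A minimal sketch under the
# assumption that bfloat16 is only available on bf16-capable GPUs (e.g. A10G/A100),
# while older GPUs (e.g. T4) use float16 and CPU-only hardware falls back to float32.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else (torch.float16 if torch.cuda.is_available() else torch.float32)
)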
# System prompts
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Model Loading ---
# Cache models globally so they are loaded only once
_models_cache = {}
def get_model_and_tokenizer(model_id_or_path):
    if model_id_or_path in _models_cache:
        return _models_cache[model_id_or_path]

    print(f"Loading model: {model_id_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        torch_dtype=DTYPE,  # bf16 on supported GPUs, otherwise fp16/fp32 (see DTYPE above)
        device_map="auto",  # Automatically place weights on available GPUs/CPU
        trust_remote_code=True,
        # attn_implementation="flash_attention_2",  # Optional: if supported by the Space hardware & transformers version
    )
    model.eval()  # Set to evaluation mode

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    _models_cache[model_id_or_path] = (model, tokenizer)
    print(f"Finished loading: {model_id_or_path}")
    return model, tokenizer
# Pre-load models when the script starts. This can take a while, so Gradio may show
# a loading screen; on Spaces it happens during the build/startup phase.
print("Pre-loading models...")

try:
    model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID)
    print("Base model loaded.")
except Exception as e:
    print(f"Error loading base model: {e}")
    model_base, tokenizer_base = None, None

# Check that the fine-tuned model directory exists before loading
if os.path.isdir(FINETUNED_MODEL_PATH):
    try:
        model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_PATH)
        print("Fine-tuned model loaded.")
    except Exception as e:
        print(f"Error loading fine-tuned model from {FINETUNED_MODEL_PATH}: {e}")
        model_ft, tokenizer_ft = None, None
else:
    print(f"Fine-tuned model path not found: {FINETUNED_MODEL_PATH}. Skipping fine-tuned model.")
    model_ft, tokenizer_ft = None, None

print("Model pre-loading complete.")
# --- Inference Function ---
def generate_chat_response(message: str, chat_history: list, model_type: str):
    if model_type == "base":
        model, tokenizer = model_base, tokenizer_base
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type == "finetuned":
        model, tokenizer = model_ft, tokenizer_ft
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model '{model_type}' is not available."
        return

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    # Apply the chat template
    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,  # Adds the <|im_start|>assistant prefix
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
    # Stream the reply: model.generate runs in a background thread and pushes
    # decoded text through a TextIteratorStreamer, which we iterate here so Gradio
    # can render partial responses as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,  # For Qwen2.5-Instruct this is <|im_end|>
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response
    thread.join()
def respond_base(message, chat_history):
    # chat_history is a list of [user_msg, assistant_msg] pairs
    yield from generate_chat_response(message, chat_history, "base")


def respond_finetuned(message, chat_history):
    yield from generate_chat_response(message, chat_history, "finetuned")
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🎬 CineGuide vs. Base Qwen2.5-7B-Instruct
Compare the fine-tuned CineGuide movie recommender with the base Qwen2.5-7B-Instruct model.
Type your movie-related query below and see how each model responds!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🗣️ Base Qwen2.5-7B-Instruct")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False)
            if model_base is None:
                gr.Markdown("⚠️ Base model could not be loaded. This chat interface will not work.")

        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide (Qwen2.5-7B)")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False)
            if model_ft is None:
                gr.Markdown("⚠️ Fine-tuned model could not be loaded. This chat interface will not work.")
    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your movie query here and press Enter...",
            container=False,
            scale=7,  # Make the textbox wider than the button
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
        # clear_button = gr.Button("🗑️ Clear All", scale=1)  # Optional single clear button

    # Predefined examples (clicking one only fills the textbox; it does not run the models)
    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick. Think more British comedy style.",
            "I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.",
            "I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.",
            "I'm going through a tough breakup and need something uplifting but not cheesy romantic.",
            "I loved Parasite and want to explore more international cinema. Where should I start?",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts (click to use)",
    )
    # Event handlers. A single shared textbox drives both chatbots: each model has
    # its own predict function, and both are attached to the same submit/click
    # events below so the two responses stream side by side.
    def base_model_predict(user_message, chat_history):
        chat_history.append((user_message, ""))  # Add the user message plus a placeholder bot reply
        for response_chunk in respond_base(user_message, chat_history[:-1]):  # Pass history without the current turn
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history

    def ft_model_predict(user_message, chat_history):
        chat_history.append((user_message, ""))
        for response_chunk in respond_finetuned(user_message, chat_history[:-1]):
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history
    # Wire the shared textbox and send button to whichever models loaded successfully.
    if model_base is not None:
        shared_input_textbox.submit(
            base_model_predict,
            [shared_input_textbox, chatbot_base],
            [chatbot_base],
        )
        submit_button.click(
            base_model_predict,
            [shared_input_textbox, chatbot_base],
            [chatbot_base],
        )

    if model_ft is not None:
        shared_input_textbox.submit(
            ft_model_predict,
            [shared_input_textbox, chatbot_ft],
            [chatbot_ft],
        )
        submit_button.click(
            ft_model_predict,
            [shared_input_textbox, chatbot_ft],
            [chatbot_ft],
        )

    # Clear the input textbox after submission. The predict handlers above already
    # received the submitted value, so an extra listener that returns "" is safe.
    if model_base is not None or model_ft is not None:
        shared_input_textbox.submit(lambda: "", [], [shared_input_textbox])
        submit_button.click(lambda: "", [], [shared_input_textbox])
    # Optional per-chat clear buttons (left disabled):
    # clear_base_btn = gr.Button("🗑️ Clear Base Chat")
    # clear_ft_btn = gr.Button("🗑️ Clear CineGuide Chat")
    # clear_base_btn.click(lambda: (None, ""), None, [chatbot_base, shared_input_textbox], queue=False)
    # clear_ft_btn.click(lambda: (None, ""), None, [chatbot_ft, shared_input_textbox], queue=False)

# --- Launch the App ---
if __name__ == "__main__":
    demo.queue()  # Enable queuing so multiple users can be served
    demo.launch(debug=True, share=False)  # share=True gives a public link when running locally