import os
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Base checkpoint to compare against.
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# Local directory containing the fine-tuned model (adapters already merged in).
FINETUNED_MODEL_PATH = "cineguide-merged"
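
# FINETUNED_MODEL_PATH is assumed to be a directory produced by merging LoRA
# adapters into the base model -- a sketch using peft, not part of this app
# ("adapter-dir" is a hypothetical path):
#
#   from peft import PeftModel
#   merged = PeftModel.from_pretrained(base_model, "adapter-dir").merge_and_unload()
#   merged.save_pretrained("cineguide-merged")
#   tokenizer.save_pretrained("cineguide-merged")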

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# Cache so each checkpoint is loaded from disk at most once per process.
_models_cache = {}


def get_model_and_tokenizer(model_id_or_path):
    if model_id_or_path in _models_cache:
        return _models_cache[model_id_or_path]

    print(f"Loading model: {model_id_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

    # Some tokenizers ship without a dedicated pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    _models_cache[model_id_or_path] = (model, tokenizer)
    print(f"Finished loading: {model_id_or_path}")
    return model, tokenizer
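
# Rough sizing note (an estimate, not measured here): a 7B model in bfloat16
# holds ~14-15 GB of weights, so serving both models side by side wants on the
# order of 30 GB of GPU memory; device_map="auto" will spill layers to CPU RAM
# when that is not available, at a significant speed cost.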


print("Pre-loading models...")
try:
    model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID)
    print("Base model loaded.")
except Exception as e:
    print(f"Error loading base model: {e}")
    model_base, tokenizer_base = None, None

# os.path.isdir() already implies existence, so no separate exists() check is needed.
if os.path.isdir(FINETUNED_MODEL_PATH):
    try:
        model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_PATH)
        print("Fine-tuned model loaded.")
    except Exception as e:
        print(f"Error loading fine-tuned model from {FINETUNED_MODEL_PATH}: {e}")
        model_ft, tokenizer_ft = None, None
else:
    print(f"Fine-tuned model path not found: {FINETUNED_MODEL_PATH}. Skipping fine-tuned model.")
    model_ft, tokenizer_ft = None, None
print("Model pre-loading complete.")


def generate_chat_response(message: str, chat_history: list, model_type: str):
    """Yield progressively longer prefixes of the model's reply for streaming."""
    if model_type == "base":
        model, tokenizer = model_base, tokenizer_base
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type == "finetuned":
        model, tokenizer = model_ft, tokenizer_ft
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model '{model_type}' is not available."
        return

    # Rebuild the conversation in the role/content format the chat template expects.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Truncate long histories so the prompt stays inside the context budget.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
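
    # For true token-by-token streaming (instead of the simulated character
    # streaming at the end of this function), one option is transformers'
    # TextIteratorStreamer driven by model.generate on a background thread --
    # a sketch, not this app's tested path:
    #
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512})
    #   thread.start()
    #   partial = ""
    #   for new_text in streamer:
    #       partial += new_text
    #       yield partial
    #   thread.join()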
full_response = "" |
|
generated_token_ids = model.generate( |
|
**inputs, |
|
max_new_tokens=512, |
|
do_sample=True, |
|
temperature=0.7, |
|
top_p=0.9, |
|
repetition_penalty=1.1, |
|
pad_token_id=tokenizer.eos_token_id, |
|
eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")] |
|
) |
|
|
|
|
|
new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:] |
|
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() |
|
response_text = response_text.replace("<|im_end|>", "").strip() |

    # Simulate streaming by yielding the accumulated text character by character.
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)
        yield full_response


def respond_base(message, chat_history):
    yield from generate_chat_response(message, chat_history, "base")


def respond_finetuned(message, chat_history):
    yield from generate_chat_response(message, chat_history, "finetuned")


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎬 CineGuide vs. Base Qwen2.5-7B-Instruct
        Compare the fine-tuned CineGuide movie recommender with the base Qwen2.5-7B-Instruct model.
        Type your movie-related query below and see how each model responds!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🗣️ Base Qwen2.5-7B-Instruct")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False)
            if model_base is None:
                gr.Markdown("⚠️ Base model could not be loaded. This chat interface will not work.")

        with gr.Column(scale=1):
            gr.Markdown("## 🤖 Fine-tuned CineGuide (Qwen2.5-7B)")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False)
            if model_ft is None:
                gr.Markdown("⚠️ Fine-tuned model could not be loaded. This chat interface will not work.")

    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False,
            placeholder="Enter your movie query here and press Enter...",
            container=False,
            scale=7,
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)

    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick. Think more British comedy style.",
            "I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.",
            "I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.",
            "I'm going through a tough breakup and need something uplifting but not cheesy romantic.",
            "I loved Parasite and want to explore more international cinema. Where should I start?",
        ],
        inputs=[shared_input_textbox],
        label="Example Prompts (click to use)",
    )

    def base_model_predict(user_message, chat_history):
        # Append the user turn with an empty reply, then overwrite it as chunks stream in.
        chat_history.append((user_message, ""))
        for response_chunk in respond_base(user_message, chat_history[:-1]):
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history

    def ft_model_predict(user_message, chat_history):
        chat_history.append((user_message, ""))
        for response_chunk in respond_finetuned(user_message, chat_history[:-1]):
            chat_history[-1] = (user_message, response_chunk)
            yield chat_history
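    # Aside: for a single model, gr.ChatInterface(respond_base) could replace
    # this manual wiring, but the side-by-side comparison needs two explicit
    # gr.Chatbot components fed from one shared textbox.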

    # Each listener fires independently, so a single submit or click triggers
    # both models on the same input.
    if model_base is not None:
        shared_input_textbox.submit(
            base_model_predict,
            [shared_input_textbox, chatbot_base],
            [chatbot_base],
        )
        submit_button.click(
            base_model_predict,
            [shared_input_textbox, chatbot_base],
            [chatbot_base],
        )

    if model_ft is not None:
        shared_input_textbox.submit(
            ft_model_predict,
            [shared_input_textbox, chatbot_ft],
            [chatbot_ft],
        )
        submit_button.click(
            ft_model_predict,
            [shared_input_textbox, chatbot_ft],
            [chatbot_ft],
        )

    # Clear the shared textbox after submission (register once if any model loaded).
    if model_base is not None or model_ft is not None:
        shared_input_textbox.submit(lambda: "", [], [shared_input_textbox])
        submit_button.click(lambda: "", [], [shared_input_textbox])


if __name__ == "__main__":
    # queue() enables the generator-based (streaming) event handlers above.
    demo.queue()
    demo.launch(debug=True, share=False)
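
# To run (assuming this file is saved as app.py; the package list is an
# assumption -- accelerate backs device_map="auto"):
#   pip install gradio torch transformers accelerate
#   python app.py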