import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""

SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
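# For reference, Qwen2.5's chat template renders conversations in ChatML form. The block
# below is an illustrative sketch of the prompt that apply_chat_template produces (not
# output of this code); the trailing assistant header is why generation is also stopped
# on <|im_end|> further down.
#
#   <|im_start|>system
#   You are CineGuide, ...<|im_end|>
#   <|im_start|>user
#   Hi! I'm looking for something funny to watch tonight.<|im_end|>
#   <|im_start|>assistant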
# --- Global Model Cache (models will be loaded on first use) ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}
# --- Model Loading Function (to be called inside decorated functions) ---
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in the cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,  # or torch.float16 if better for the available GPU
            device_map="auto",  # utilizes the GPU allocated by @spaces.GPU
            trust_remote_code=True,
        )
        model.eval()
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        _models_cache[model_key] = "error"  # mark as error to avoid retrying
        _models_cache[tokenizer_key] = "error"
        raise  # re-raise the exception so it surfaces in the Gradio UI or Space logs
# --- Inference Function (ensures models are loaded before generating) ---
def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        # Sanity check on the fine-tuned model ID itself.
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID} (Type: {type(FINETUNED_MODEL_ID)})")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:  # normally caught by the "error" check or the exception above
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Build the conversation in the messages format used by the Chatbot components.
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(chat_history)  # chat_history is already in type="messages" format
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)

    # Stop on the regular EOS token and on Qwen's <|im_end|> turn delimiter (if it is in the vocab).
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))  # de-duplicate in case EOS and <|im_end|> coincide

    with torch.no_grad():
        generated_token_ids = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
            repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
        )
    new_tokens = generated_token_ids[0, inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

    # Simulated streaming: the full response is generated above, then yielded character by character.
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)  # adjust for desired typing speed
        yield full_response
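
# The loop above only simulates streaming: the reply is fully generated first, then re-yielded
# character by character. For reference, the helper below is a minimal sketch of true token-level
# streaming with transformers' TextIteratorStreamer. It is illustrative and not wired into the app;
# it assumes `model`, `tokenizer`, `inputs`, and `eos_ids` are prepared exactly as above.
def stream_generate_sketch(model, tokenizer, inputs, eos_ids):
    from threading import Thread
    from transformers import TextIteratorStreamer

    # The streamer decodes tokens as they are produced, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs, streamer=streamer, max_new_tokens=512, do_sample=True, temperature=0.7,
        top_p=0.9, repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_ids,
    )
    # generate() blocks, so run it in a background thread and consume the streamer here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial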
# --- Gradio UI Event Handlers ---
# On ZeroGPU (shared GPU pool) hardware, functions that need a GPU must be decorated with
# @spaces.GPU from the `spaces` package provided in the Space environment. Model loading and
# generation happen inside the prediction handlers below, so they are the functions to
# decorate; a sketch of that wiring follows their definitions.
def base_model_predict_decorated(user_message, chat_history):
    """Streams the base model's reply; model loading is triggered on first use."""
    chat_history = list(chat_history or []) + [{"role": "user", "content": user_message}]
    try:
        # generate_chat_response calls load_model_and_tokenizer internally if needed and
        # appends the user message itself, so pass the history without the new turn.
        for partial_response in generate_chat_response(user_message, chat_history[:-1], "base"):
            yield chat_history + [{"role": "assistant", "content": partial_response}]
    except Exception as e:
        print(f"Error in base_model_predict_decorated: {e}")
        yield chat_history + [{"role": "assistant", "content": f"Error generating base model response: {e}"}]

def ft_model_predict_decorated(user_message, chat_history):
    """Streams the fine-tuned CineGuide reply; model loading is triggered on first use."""
    chat_history = list(chat_history or []) + [{"role": "user", "content": user_message}]
    try:
        for partial_response in generate_chat_response(user_message, chat_history[:-1], "finetuned"):
            yield chat_history + [{"role": "assistant", "content": partial_response}]
    except Exception as e:
        print(f"Error in ft_model_predict_decorated: {e}")
        yield chat_history + [{"role": "assistant", "content": f"Error generating fine-tuned response: {e}"}]
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
with the base {BASE_MODEL_ID} model.
Type your movie-related query below and see how each model responds!

**Note:** Models are loaded on first use and may take some time. This Space uses the shared GPU pool.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages")
        with gr.Column(scale=1):
            gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages")
    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
        ],
        inputs=[shared_input_textbox], label="Example Prompts"
    )
    # The handlers are wired directly here; on ZeroGPU, swap in the @spaces.GPU-decorated
    # wrappers defined above so a GPU is requested whenever an event fires.
    submit_button.click(
        base_model_predict_decorated,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        # api_name="base_predict"  # optional
    )
    submit_button.click(
        ft_model_predict_decorated,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        # api_name="ft_predict"  # optional
    )

    # Handle the textbox submit event for both models.
    shared_input_textbox.submit(
        base_model_predict_decorated,
        [shared_input_textbox, chatbot_base],
        [chatbot_base]
    )
    shared_input_textbox.submit(
        ft_model_predict_decorated,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft]
    )

    def clear_textbox_fn():
        return ""

    submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
if __name__ == "__main__":
    # Hardware selection (e.g. ZeroGPU) and Python dependencies are configured at the Space level
    # (Space settings / README.md metadata and requirements.txt), not in this file; a plausible
    # requirements.txt is sketched below.
    demo.queue()
    demo.launch(debug=True)
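
# For completeness, a plausible requirements.txt for this Space, inferred only from the imports
# used above (the package list is an assumption; versions are not pinned here):
#
#   gradio
#   torch
#   transformers
#   accelerate   # required for device_map="auto"
#   spaces       # ZeroGPU @spaces.GPU decorator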