import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import time
import os
import spaces  # ZeroGPU helper package; provides the @spaces.GPU decorator used on the handlers below
# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema
When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
# --- Global Model Cache (models will be loaded on first use) ---
_models_cache = {
"base": None,
"finetuned": None,
"tokenizer_base": None,
"tokenizer_ft": None,
}
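# Models are cached here after their first load; loading is deferred to the GPU-decorated
# handlers below, since ZeroGPU only attaches a GPU while a decorated function is running.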
# --- Model Loading Function (to be called inside decorated functions) ---
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
"""Loads a model and tokenizer if not already in cache."""
if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
return _models_cache[model_key], _models_cache[tokenizer_key]
print(f"Loading {model_key} model ({model_identifier})...")
try:
tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_identifier,
torch_dtype=torch.bfloat16, # Or torch.float16 if better for available GPU
device_map="auto", # This will utilize the GPU allocated by @spaces.GPU
trust_remote_code=True,
)
model.eval()
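        # Some tokenizers ship without a pad token; fall back to the EOS token so
        # generation has a valid pad id.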
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
_models_cache[model_key] = model
_models_cache[tokenizer_key] = tokenizer
print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
return model, tokenizer
except Exception as e:
print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
_models_cache[model_key] = "error" # Mark as error to avoid retrying
_models_cache[tokenizer_key] = "error"
raise # Re-raise the exception to see it in Gradio UI or logs
# --- Inference Function (loads the requested model on demand, then streams a response) ---
def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
model, tokenizer = None, None
system_prompt = ""
if model_type_to_load == "base":
if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
return
model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
system_prompt = SYSTEM_PROMPT_BASE
elif model_type_to_load == "finetuned":
# Critical check for the FINETUNED_MODEL_ID itself
if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID} (Type: {type(FINETUNED_MODEL_ID)})")
yield "Error: Fine-tuned model ID is not configured correctly."
return
if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
return
model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
system_prompt = SYSTEM_PROMPT_CINEGUIDE
else:
yield "Invalid model type."
return
if model is None or tokenizer is None: # Should be caught by "error" check or exception above
yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
return
conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
conversation.extend(chat_history) # Assuming chat_history is already type="messages"
conversation.append({"role": "user", "content": message})
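    # Render the conversation with the model's chat template (Qwen uses the ChatML-style
    # <|im_start|>/<|im_end|> format) and append the generation prompt for the assistant turn.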
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
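    # Tokenize and truncate so the prompt stays comfortably within the model's context window.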
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
eos_tokens_ids = [tokenizer.eos_token_id]
    # Qwen models use <|im_end|> as the end-of-turn marker; add it as an extra stop token
    # when it is actually present in the vocabulary.
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
        eos_tokens_ids.append(im_end_id)
    # De-duplicate in case eos_token_id and <|im_end|> map to the same id.
    eos_tokens_ids = list(set(eos_tokens_ids))
generated_token_ids = model.generate(
**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
)
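    # Keep only the newly generated tokens; everything up to input_ids' length is the prompt echo.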
new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()
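    # model.generate() above is non-streaming, so simulate streaming by yielding the
    # decoded text incrementally.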
full_response = ""
for char in response_text:
full_response += char
time.sleep(0.005) # Adjust for desired speed
yield full_response
# --- Gradio UI Event Handlers ---
# On ZeroGPU hardware the heavy work must run inside functions decorated with @spaces.GPU;
# the decorator attaches a GPU from the shared pool for the duration of the call.
# Model loading and generation therefore happen inside these handlers, not at import time.
@spaces.GPU
def base_model_predict_decorated(user_message, chat_history):
    """Streams the base model's reply; yields the full message list expected by a type="messages" Chatbot."""
    history = list(chat_history) + [{"role": "user", "content": user_message}]
    try:
        # Loading (on first use) and generation both run inside this GPU-allocated call.
        for partial_response in generate_chat_response(user_message, chat_history, "base"):
            yield history + [{"role": "assistant", "content": partial_response}]
    except Exception as e:
        print(f"Error in base_model_predict_decorated: {e}")
        yield history + [{"role": "assistant", "content": f"Error generating base model response: {e}"}]
@spaces.GPU
def ft_model_predict_decorated(user_message, chat_history):
    """Streams the fine-tuned CineGuide reply; yields the full message list for the Chatbot."""
    history = list(chat_history) + [{"role": "user", "content": user_message}]
    try:
        # Loading (on first use) and generation both run inside this GPU-allocated call.
        for partial_response in generate_chat_response(user_message, chat_history, "finetuned"):
            yield history + [{"role": "assistant", "content": partial_response}]
    except Exception as e:
        print(f"Error in ft_model_predict_decorated: {e}")
        yield history + [{"role": "assistant", "content": f"Error generating fine-tuned response: {e}"}]
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
# 🎬 CineGuide vs. Base {BASE_MODEL_ID}
Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
with the base {BASE_MODEL_ID} model.
Type your movie-related query below and see how each model responds!
**Note:** Models are loaded on first use and may take some time. Using shared GPU pool.
"""
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages") # Use type="messages"
with gr.Column(scale=1):
gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages") # Use type="messages"
with gr.Row():
shared_input_textbox = gr.Textbox(
show_label=False, placeholder="Enter your movie query...", container=False, scale=7
)
submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
gr.Examples(
examples=[
"Hi! I'm looking for something funny to watch tonight.",
"I love dry, witty humor more than slapstick.",
"I'm really into complex sci-fi movies that make you think.",
],
inputs=[shared_input_textbox], label="Example Prompts"
)
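    # gr.Examples only fills the textbox; the user still presses Send (or Enter) to query both models.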
    # The handlers above are already decorated with @spaces.GPU, so Gradio can call them
    # directly as event callbacks; no extra wrapper is needed for the ZeroGPU pool.
submit_button.click(
base_model_predict_decorated,
[shared_input_textbox, chatbot_base],
[chatbot_base],
# api_name="base_predict" # Optional
)
submit_button.click(
ft_model_predict_decorated,
[shared_input_textbox, chatbot_ft],
[chatbot_ft],
# api_name="ft_predict" # Optional
)
# Handle textbox submit event for both
shared_input_textbox.submit(
base_model_predict_decorated,
[shared_input_textbox, chatbot_base],
[chatbot_base]
)
shared_input_textbox.submit(
ft_model_predict_decorated,
[shared_input_textbox, chatbot_ft],
[chatbot_ft]
)
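    # Clearing runs as a separate listener; the predict handlers above already received the
    # submitted text as their input, so wiping the textbox here is safe.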
    def clear_textbox_fn():
        return ""
submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
if __name__ == "__main__":
    # GPU allocation on ZeroGPU is handled entirely by the @spaces.GPU decorators above.
    # demo.queue() serializes concurrent requests and enables streaming of the generator outputs.
demo.queue()
demo.launch(debug=True) |