import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft" # Assuming this is correct

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache (models will be loaded on first use) ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}

# --- Model Loading Function (called lazily from the GPU event handlers) ---
def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16, # Or torch.float16 if better for available GPU
            device_map="auto", # This will utilize the GPU allocated by @spaces.GPU
            trust_remote_code=True,
        )
        model.eval()

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
        
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        _models_cache[model_key] = "error" # Mark as error to avoid retrying
        _models_cache[tokenizer_key] = "error"
        raise # Re-raise the exception to see it in Gradio UI or logs

# --- Inference Function (loads models on demand) ---
def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        # Critical check for the FINETUNED_MODEL_ID itself
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID} (Type: {type(FINETUNED_MODEL_ID)})")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None: # Should be caught by "error" check or exception above
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(chat_history) # Assuming chat_history is already type="messages"
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
    
    # Stop generation on the tokenizer's EOS token and, when it exists in the vocab,
    # Qwen's <|im_end|> turn delimiter.
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))  # De-duplicate in case eos and <|im_end|> coincide

    generated_token_ids = model.generate(
        **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
        repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, eos_token_id=eos_tokens_ids
    )
    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

    # Replay the finished response character by character to simulate streaming
    # (see the TextIteratorStreamer sketch below this function for true token streaming).
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)  # Adjust for the desired typing speed
        yield full_response
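
# --- Hedged sketch: true token streaming (optional alternative) ---
# The loop above only simulates streaming after generation has finished. A minimal
# sketch using transformers' TextIteratorStreamer, assuming the same `model`,
# `tokenizer`, and `inputs` objects built in generate_chat_response; generation
# parameters mirror the call above:
def stream_response_tokens(model, tokenizer, inputs):
    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9,
        repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, streamer=streamer,
    )
    # generate() blocks, so run it in a background thread and consume the streamer here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial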


# --- Gradio UI Event Handlers ---
# On ZeroGPU hardware, the functions that touch the GPU are expected to carry the
# @spaces.GPU decorator from the `spaces` package so the platform allocates a GPU
# for the duration of the call; see the hedged sketch after the two handlers below.
# Model loading and generation happen inside these handlers, so the GPU is only
# requested when a request actually arrives.

def base_model_predict_decorated(user_message, chat_history):
    # generate_chat_response loads the model lazily on first use. The Chatbot is
    # configured with type="messages", so each yield must be the full updated message
    # list (prior history + new user turn + the growing assistant turn), not a string.
    chat_history = chat_history or []
    try:
        for chunk in generate_chat_response(user_message, chat_history, "base"):
            yield chat_history + [
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": chunk},
            ]
    except Exception as e:
        print(f"Error in base_model_predict_decorated: {e}")
        yield chat_history + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": f"Error generating base model response: {e}"},
        ]

def ft_model_predict_decorated(user_message, chat_history):
    chat_history = chat_history or []
    try:
        for chunk in generate_chat_response(user_message, chat_history, "finetuned"):
            yield chat_history + [
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": chunk},
            ]
    except Exception as e:
        print(f"Error in ft_model_predict_decorated: {e}")
        yield chat_history + [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": f"Error generating fine-tuned response: {e}"},
        ]

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
        Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
        with the base {BASE_MODEL_ID} model.
        Type your movie-related query below and see how each model responds!
        **Note:** Models are loaded on first use, so the first response may take a while; inference runs on a shared GPU pool.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
            chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, type="messages") # Use type="messages"
        with gr.Column(scale=1):
            gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
            chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, type="messages") # Use type="messages"

    with gr.Row():
        shared_input_textbox = gr.Textbox(
            show_label=False, placeholder="Enter your movie query...", container=False, scale=7
        )
        submit_button = gr.Button("✉️ Send", variant="primary", scale=1)

    gr.Examples(
        examples=[
            "Hi! I'm looking for something funny to watch tonight.",
            "I love dry, witty humor more than slapstick.",
            "I'm really into complex sci-fi movies that make you think.",
        ],
        inputs=[shared_input_textbox], label="Example Prompts"
    )

    # Wire both handlers to the button click and the textbox submit events. If the
    # @spaces.GPU wrappers are required, pass the decorated versions from the sketch
    # above instead of the plain handlers.
    submit_button.click(
        base_model_predict_decorated,
        [shared_input_textbox, chatbot_base],
        [chatbot_base],
        # api_name="base_predict" # Optional
    )
    submit_button.click(
        ft_model_predict_decorated,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft],
        # api_name="ft_predict" # Optional
    )
    # Handle textbox submit event for both
    shared_input_textbox.submit(
        base_model_predict_decorated,
        [shared_input_textbox, chatbot_base],
        [chatbot_base]
    )
    shared_input_textbox.submit(
        ft_model_predict_decorated,
        [shared_input_textbox, chatbot_ft],
        [chatbot_ft]
    )

    def clear_textbox_fn(): return ""
    submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
    shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
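
# --- Hedged note: Space dependencies (assumption) ---
# Dependencies are declared in the Space's requirements.txt rather than in code.
# A plausible, unpinned requirements.txt for this app (names are assumptions):
#   gradio
#   torch
#   transformers
#   accelerate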


if __name__ == "__main__":
    # GPU allocation on ZeroGPU is handled by the @spaces.GPU decorator (see the
    # sketch above) and dependencies by requirements.txt; nothing extra is needed here.
    demo.queue()
    demo.launch(debug=True)