import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces
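# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces;
# a GPU is attached to a decorated function only for the duration of each call.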

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}
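# Each cache entry is None (not yet loaded), the loaded object, or the string
# "error", used below as a sentinel marking a failed load attempt.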

def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
        
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise

def generate_chat_response(message: str, chat_history: list, model_type_to_load: str):
    """Generate response using specified model type."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    
    # Add chat history
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})

    # Build the prompt from the conversation using the tokenizer's chat template
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
    
    # Prepare EOS tokens: Qwen's chat template ends each turn with <|im_end|>,
    # so add it as an extra stop token when the tokenizer can resolve it.
    eos_tokens_ids = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
        eos_tokens_ids.append(im_end_id)
    eos_tokens_ids = list(set(eos_tokens_ids))

    # Generate
    with torch.no_grad():
        generated_token_ids = model.generate(
            **inputs, 
            max_new_tokens=512, 
            do_sample=True, 
            temperature=0.7, 
            top_p=0.9,
            repetition_penalty=1.1, 
            pad_token_id=tokenizer.pad_token_id, 
            eos_token_id=eos_tokens_ids
        )
    
    new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

    # Replay the completed response character by character to simulate streaming
    full_response = ""
    for char in response_text:
        full_response += char
        time.sleep(0.005)
        yield full_response

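# On ZeroGPU Spaces, a GPU is allocated only while a @spaces.GPU-decorated call
# runs, which is why the decorator sits on these predict functions rather than
# on model loading at import time.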
@spaces.GPU
def base_model_predict(user_message, chat_history):
    """Predict using base model - decorated with @spaces.GPU."""
    try:
        bot_response_stream = generate_chat_response(user_message, chat_history, "base")
        for chunk in bot_response_stream:
            yield chunk
    except Exception as e:
        print(f"Error in base_model_predict: {e}")
        yield f"Error generating base model response: {str(e)}"

@spaces.GPU
def ft_model_predict(user_message, chat_history):
    """Predict using fine-tuned model - decorated with @spaces.GPU."""
    try:
        bot_response_stream = generate_chat_response(user_message, chat_history, "finetuned")
        for chunk in bot_response_stream:
            yield chunk
    except Exception as e:
        print(f"Error in ft_model_predict: {e}")
        yield f"Error generating fine-tuned response: {str(e)}"

def format_chat_history(history):
    """Normalize Gradio chat history into the messages format the models expect."""
    formatted_history = []
    for chat in history:
        if isinstance(chat, dict) and 'role' in chat:
            formatted_history.append(chat)
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            formatted_history.extend([
                {"role": "user", "content": chat[0]},
                {"role": "assistant", "content": chat[1]}
            ])
    return formatted_history
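# Example of the two history shapes normalized above (hypothetical values):
#   messages format: [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
#   pair format:     [["Hi", "Hello!"]]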

def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    formatted_history = format_chat_history(history)
    response_gen = base_model_predict(message, formatted_history)
    
    for response in response_gen:
        yield response

def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    formatted_history = format_chat_history(history)
    response_gen = ft_model_predict(message, formatted_history)
    
    for response in response_gen:
        yield response

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base Model Comparison
        Compare the fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.
        
        **Base Model:** `{BASE_MODEL_ID}`  
        **Fine-tuned Model:** `{FINETUNED_MODEL_ID}`
        
        Type your movie-related query below and see how each model responds!
        
        ⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## 🗣️ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                theme="soft",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False,
                retry_btn=None,
                undo_btn="⤴️ Undo",
                clear_btn="🗑️ Clear"
            )
            
        with gr.Column(scale=1):
            gr.Markdown(f"## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown(f"*Specialized for movie recommendations*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                theme="soft",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False,
                retry_btn=None,
                undo_btn="⤴️ Undo",
                clear_btn="🗑️ Clear"
            )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()
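
# Outside a ZeroGPU Space the @spaces.GPU decorator is effectively a no-op, so
# the app should also run locally on a machine with enough GPU memory.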