import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import spaces

# --- Configuration ---
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
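# Note: FINETUNED_MODEL_ID is assumed to be a fully merged checkpoint, since it is
# loaded below with AutoModelForCausalLM rather than as a PEFT/LoRA adapter.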

SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
1. Provide personalized movie recommendations based on user preferences
2. Give brief, compelling rationales for why you recommend each movie
3. Ask thoughtful follow-up questions to better understand user tastes
4. Maintain an enthusiastic but not overwhelming tone about cinema

When recommending movies, always explain WHY the movie fits their preferences."""
SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."

# --- Global Model Cache ---
_models_cache = {
    "base": None,
    "finetuned": None,
    "tokenizer_base": None,
    "tokenizer_ft": None,
}
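# Models are loaded lazily on first use and kept in this cache so the two 7B
# checkpoints are not reloaded (or re-downloaded) on every request.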

def load_model_and_tokenizer(model_identifier: str, model_key: str, tokenizer_key: str):
    """Loads a model and tokenizer if not already in cache."""
    if _models_cache[model_key] is not None and _models_cache[tokenizer_key] is not None:
        print(f"Using cached {model_key} model and {tokenizer_key} tokenizer.")
        return _models_cache[model_key], _models_cache[tokenizer_key]

    print(f"Loading {model_key} model ({model_identifier})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_identifier, 
            trust_remote_code=True,
            use_auth_token=False  # Ensure we're not using auth for public models
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_identifier,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            use_auth_token=False  # Ensure we're not using auth for public models
        )
        model.eval()

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
        
        _models_cache[model_key] = model
        _models_cache[tokenizer_key] = tokenizer
        print(f"Finished loading and cached {model_key} and {tokenizer_key}.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading {model_key} model ({model_identifier}): {e}")
        print(f"Error type: {type(e).__name__}")
        if "404" in str(e) or "not found" in str(e).lower():
            print(f"Model {model_identifier} not found. Please check the model ID.")
        _models_cache[model_key] = "error"
        _models_cache[tokenizer_key] = "error"
        raise

def convert_gradio_history_to_messages(history):
    """Convert Gradio ChatInterface history (tuple pairs or message dicts) to messages format."""
    messages = []
    for exchange in history:
        if isinstance(exchange, dict) and "role" in exchange and "content" in exchange:
            # History entries are already role/content dicts (Gradio's "messages" format).
            messages.append({"role": exchange["role"], "content": str(exchange["content"])})
        elif isinstance(exchange, (list, tuple)) and len(exchange) == 2:
            user_msg, assistant_msg = exchange
            if user_msg:  # Only add if not empty
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:  # Only add if not empty
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    return messages

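# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of
# each decorated call; on hardware with a persistent GPU it is effectively a no-op.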
@spaces.GPU
def generate_chat_response(message: str, history: list, model_type_to_load: str):
    """Generate response using specified model type."""
    model, tokenizer = None, None
    system_prompt = ""

    if model_type_to_load == "base":
        if _models_cache["base"] == "error" or _models_cache["tokenizer_base"] == "error":
            yield f"Base model ({BASE_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(BASE_MODEL_ID, "base", "tokenizer_base")
        system_prompt = SYSTEM_PROMPT_BASE
    elif model_type_to_load == "finetuned":
        if not FINETUNED_MODEL_ID or not isinstance(FINETUNED_MODEL_ID, str):
            print(f"CRITICAL ERROR: FINETUNED_MODEL_ID is invalid: {FINETUNED_MODEL_ID}")
            yield "Error: Fine-tuned model ID is not configured correctly."
            return
        if _models_cache["finetuned"] == "error" or _models_cache["tokenizer_ft"] == "error":
            yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) failed to load previously."
            return
        model, tokenizer = load_model_and_tokenizer(FINETUNED_MODEL_ID, "finetuned", "tokenizer_ft")
        system_prompt = SYSTEM_PROMPT_CINEGUIDE
    else:
        yield "Invalid model type."
        return

    if model is None or tokenizer is None:
        yield f"Model or tokenizer for '{model_type_to_load}' is not available after attempting load."
        return

    # Prepare conversation
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    
    # Convert and add chat history
    formatted_history = convert_gradio_history_to_messages(history)
    conversation.extend(formatted_history)
    conversation.append({"role": "user", "content": message})

    try:
        # Generate response
        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
        
        # Prepare EOS tokens: Qwen chat templates close each turn with <|im_end|>,
        # so treat it as an extra stop token alongside the tokenizer's default EOS.
        eos_tokens_ids = [tokenizer.eos_token_id]
        im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != getattr(tokenizer, "unk_token_id", None):
            eos_tokens_ids.append(im_end_id)
        eos_tokens_ids = list(set(eos_tokens_ids))

        # Generate
        with torch.no_grad():
            generated_token_ids = model.generate(
                **inputs, 
                max_new_tokens=512, 
                do_sample=True, 
                temperature=0.7, 
                top_p=0.9,
                repetition_penalty=1.1, 
                pad_token_id=tokenizer.pad_token_id, 
                eos_token_id=eos_tokens_ids
            )
        
        new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip().replace("<|im_end|>", "").strip()

        # Simulated streaming: the full response is generated above and then yielded
        # character by character (true token streaming would use transformers'
        # TextIteratorStreamer on a background generation thread).
        full_response = ""
        for char in response_text:
            full_response += char
            time.sleep(0.005)
            yield full_response
            
    except Exception as e:
        print(f"Error during generation: {e}")
        yield f"Error during text generation: {str(e)}"

def respond_base(message, history):
    """Handle base model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "base")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_base: {e}")
        yield f"Error: {str(e)}"

def respond_ft(message, history):
    """Handle fine-tuned model response for Gradio ChatInterface."""
    try:
        response_gen = generate_chat_response(message, history, "finetuned")
        for response in response_gen:
            yield response
    except Exception as e:
        print(f"Error in respond_ft: {e}")
        yield f"Error: {str(e)}"

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 CineGuide Comparison") as demo:
    gr.Markdown(
        f"""
        # 🎬 CineGuide vs. Base Model Comparison
        Compare your fine-tuned CineGuide movie recommender with the base {BASE_MODEL_ID.split('/')[-1]} model.
        
        **Base Model:** `{BASE_MODEL_ID}` (Standard Assistant)

        **Fine-tuned Model:** `{FINETUNED_MODEL_ID}` (CineGuide - Specialized for Movies)
        
        Type your movie-related query below and see how fine-tuning improves movie recommendations!
        
        ⚠️ **Note:** Models are loaded on first use and may take 30-60 seconds initially.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"## πŸ—£οΈ Base Model")
            gr.Markdown(f"*{BASE_MODEL_ID.split('/')[-1]}*")
            chatbot_base = gr.ChatInterface(
                respond_base,
                textbox=gr.Textbox(placeholder="Ask about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )
            
        with gr.Column(scale=1):
            gr.Markdown(f"## 🎬 CineGuide (Fine-tuned)")
            gr.Markdown(f"*Specialized movie recommendation model*")
            chatbot_ft = gr.ChatInterface(
                respond_ft,
                textbox=gr.Textbox(placeholder="Ask CineGuide about movies...", container=False, scale=7),
                title="",
                description="",
                examples=[
                    "Hi! I'm looking for something funny to watch tonight.",
                    "I love dry, witty humor more than slapstick.",
                    "I'm really into complex sci-fi movies that make you think.",
                    "Can you recommend a good thriller?",
                    "What's a good romantic comedy from the 2000s?"
                ],
                cache_examples=False
            )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()