serhany committed on
Commit 0c08550 · verified · 1 Parent(s): 78407e7

Upload 2 files

Files changed (2)
  1. app.py +267 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,267 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+ import time
+ import os
+
+ # --- Configuration ---
+ BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
+ # The fine-tuned model hosted on the Hugging Face Hub
+ FINETUNED_MODEL_ID = "serhany/cineguide-qwen2.5-7b-instruct-ft"
+
+ # System prompts
+ SYSTEM_PROMPT_CINEGUIDE = """You are CineGuide, a knowledgeable and friendly movie recommendation assistant. Your goal is to:
+ 1. Provide personalized movie recommendations based on user preferences
+ 2. Give brief, compelling rationales for why you recommend each movie
+ 3. Ask thoughtful follow-up questions to better understand user tastes
+ 4. Maintain an enthusiastic but not overwhelming tone about cinema
+
+ When recommending movies, always explain WHY the movie fits their preferences."""
+ SYSTEM_PROMPT_BASE = "You are a helpful AI assistant."
+
+ # --- Model Loading ---
+ _models_cache = {}
+
+ def get_model_and_tokenizer(model_id_or_path, is_local_path=False):  # Added is_local_path for flexibility
+     if model_id_or_path in _models_cache:
+         return _models_cache[model_id_or_path]
+
+     print(f"Loading model: {model_id_or_path}")
+     # For models from the Hub, trust_remote_code is often needed for custom architectures like Qwen.
+     # For local paths, it might also be needed if they were saved with trust_remote_code=True.
+     tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id_or_path,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True,
+         # attn_implementation="flash_attention_2",  # Optional
+     )
+     model.eval()
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     # Ensure pad_token_id is also set if pad_token is set
+     if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     _models_cache[model_id_or_path] = (model, tokenizer)
+     print(f"Finished loading: {model_id_or_path}")
+     return model, tokenizer
+
+ print("Pre-loading models...")
+ model_base, tokenizer_base = None, None
+ model_ft, tokenizer_ft = None, None
+
+ try:
+     model_base, tokenizer_base = get_model_and_tokenizer(BASE_MODEL_ID)
+     print("Base model loaded.")
+ except Exception as e:
+     print(f"Error loading base model ({BASE_MODEL_ID}): {e}")
+
+ try:
+     model_ft, tokenizer_ft = get_model_and_tokenizer(FINETUNED_MODEL_ID)
+     print("Fine-tuned model loaded.")
+ except Exception as e:
+     print(f"Error loading fine-tuned model ({FINETUNED_MODEL_ID}): {e}")
+
+ print("Model pre-loading complete.")
+
+ # --- Inference Function ---
+ def generate_chat_response(message: str, chat_history: list, model_type: str):
+     # Select the model/tokenizer pair and system prompt for the requested variant.
+     if model_type == "base":
+         if model_base is None or tokenizer_base is None:
+             yield f"Base model ({BASE_MODEL_ID}) is not available."
+             return
+         model, tokenizer = model_base, tokenizer_base
+         system_prompt = SYSTEM_PROMPT_BASE
+     elif model_type == "finetuned":
+         if model_ft is None or tokenizer_ft is None:
+             yield f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."
+             return
+         model, tokenizer = model_ft, tokenizer_ft
+         system_prompt = SYSTEM_PROMPT_CINEGUIDE
+     else:
+         yield "Invalid model type."
+         return
+
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+
+     for user_msg, assistant_msg in chat_history:
+         if user_msg:  # Ensure user_msg is not None
+             conversation.append({"role": "user", "content": user_msg})
+         if assistant_msg:  # Ensure assistant_msg is not None
+             conversation.append({"role": "assistant", "content": assistant_msg})
+     conversation.append({"role": "user", "content": message})
+
+     prompt = tokenizer.apply_chat_template(
+         conversation,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1800).to(model.device)
+
+     full_response = ""
+     # eos_token_id may be a list when multiple end-of-sequence tokens are possible.
+     eos_tokens_ids = [tokenizer.eos_token_id]
+     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+     # Append <|im_end|> only if it is actually in the vocabulary and not already counted.
+     if im_end_id is not None and im_end_id != tokenizer.unk_token_id and im_end_id not in eos_tokens_ids:
+         eos_tokens_ids.append(im_end_id)
+
+     generated_token_ids = model.generate(
+         **inputs,
+         max_new_tokens=512,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         repetition_penalty=1.1,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=eos_tokens_ids
+     )
+
+     new_tokens = generated_token_ids[0, inputs['input_ids'].shape[1]:]
+     response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+     response_text = response_text.replace("<|im_end|>", "").strip()
+
+     # Stream the finished reply to the UI character by character.
+     for char in response_text:
+         full_response += char
+         time.sleep(0.005)
+         yield full_response
+
+ def respond_base(message, chat_history):
+     yield from generate_chat_response(message, chat_history, "base")
+
+ def respond_finetuned(message, chat_history):
+     yield from generate_chat_response(message, chat_history, "finetuned")
+
+
+ # --- Gradio UI ---
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         f"""
+ # 🎬 CineGuide vs. Base {BASE_MODEL_ID}
+ Compare the fine-tuned CineGuide movie recommender (loaded from `{FINETUNED_MODEL_ID}`)
+ with the base {BASE_MODEL_ID} model.
+ Type your movie-related query below and see how each model responds!
+ """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(f"## 🗣️ Base {BASE_MODEL_ID}")
+             chatbot_base = gr.Chatbot(label="Base Model Chat", height=500, bubble_full_width=False)
+             if model_base is None:
+                 gr.Markdown(f"⚠️ Base model ({BASE_MODEL_ID}) could not be loaded.")
+
+         with gr.Column(scale=1):
+             gr.Markdown(f"## 🤖 Fine-tuned CineGuide (from {FINETUNED_MODEL_ID})")
+             chatbot_ft = gr.Chatbot(label="CineGuide Chat", height=500, bubble_full_width=False)
+             if model_ft is None:
+                 gr.Markdown(f"⚠️ Fine-tuned model ({FINETUNED_MODEL_ID}) could not be loaded.")
+
+     with gr.Row():
+         shared_input_textbox = gr.Textbox(
+             show_label=False,
+             placeholder="Enter your movie query here and press Enter...",
+             container=False,
+             scale=7,
+         )
+         submit_button = gr.Button("✉️ Send", variant="primary", scale=1)
+
+     gr.Examples(
+         examples=[
+             "Hi! I'm looking for something funny to watch tonight.",
+             "I love dry, witty humor more than slapstick. Think more British comedy style.",
+             "I'm really into complex sci-fi movies that make you think. I loved Arrival and Blade Runner 2049.",
+             "I need help planning a family movie night. We have kids aged 8, 11, and 14, plus adults.",
+             "I'm going through a tough breakup and need something uplifting but not cheesy romantic.",
+             "I loved Parasite and want to explore more international cinema. Where should I start?",
+         ],
+         inputs=[shared_input_textbox],
+         label="Example Prompts (click to use)"
+     )
+
+     def base_model_predict(user_message, chat_history):
+         if model_base is None:  # Guard: base model failed to load
+             chat_history.append((user_message, f"Base model ({BASE_MODEL_ID}) is not available."))
+             yield chat_history
+             return
+
+         chat_history.append((user_message, ""))
+         for response_chunk in respond_base(user_message, chat_history[:-1]):
+             chat_history[-1] = (user_message, response_chunk)
+             yield chat_history
+
+     def ft_model_predict(user_message, chat_history):
+         if model_ft is None:  # Guard: fine-tuned model failed to load
+             chat_history.append((user_message, f"Fine-tuned model ({FINETUNED_MODEL_ID}) is not available."))
+             yield chat_history
+             return
+
+         chat_history.append((user_message, ""))
+         for response_chunk in respond_finetuned(user_message, chat_history[:-1]):
+             chat_history[-1] = (user_message, response_chunk)
+             yield chat_history
+
+     # Event handlers
+     actions = []
+     if model_base is not None:
+         actions.append(
+             shared_input_textbox.submit(
+                 base_model_predict,
+                 [shared_input_textbox, chatbot_base],
+                 [chatbot_base],
+                 queue=True
+             )
+         )
+         actions.append(
+             submit_button.click(
+                 base_model_predict,
+                 [shared_input_textbox, chatbot_base],
+                 [chatbot_base],
+                 queue=True
+             )
+         )
+
+     if model_ft is not None:
+         actions.append(
+             shared_input_textbox.submit(
+                 ft_model_predict,
+                 [shared_input_textbox, chatbot_ft],
+                 [chatbot_ft],
+                 queue=True
+             )
+         )
+         actions.append(
+             submit_button.click(
+                 ft_model_predict,
+                 [shared_input_textbox, chatbot_ft],
+                 [chatbot_ft],
+                 queue=True
+             )
+         )
+
+     # Clear textbox after all submits are queued. This is slightly simplified.
+     # For a more robust clear, you might need to chain these events or use gr.Group.
+     def clear_textbox_fn():
+         return ""
+
+     if actions:  # If any model is active
+         shared_input_textbox.submit(clear_textbox_fn, [], [shared_input_textbox])
+         submit_button.click(clear_textbox_fn, [], [shared_input_textbox])
+
+
+ # --- Launch the App ---
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(debug=True)  # share=True for public link if running locally
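
Note on streaming: the loop at the end of generate_chat_response only simulates streaming by decoding the full reply first and re-yielding it one character at a time with a short sleep, and the TextStreamer import is never used. Below is a minimal sketch, not part of the committed code, of true token streaming with transformers' TextIteratorStreamer; the helper name stream_qwen_response and the reuse of the existing model, tokenizer, and inputs objects are illustrative assumptions.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_qwen_response(model, tokenizer, inputs, max_new_tokens=512):
    # Illustrative sketch: yield partial text as tokens are generated,
    # instead of replaying the finished reply character by character.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        streamer=streamer,
    )
    # model.generate() blocks, so run it in a background thread and consume
    # the streamer from this generator.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial
    thread.join()

Inside generate_chat_response, this would replace the decode-and-replay block with a simple yield from stream_qwen_response(model, tokenizer, inputs).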
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.7.1+cu118
+ transformers
+ gradio
+ accelerate
+ datasets
+ peft
+ trl
+ scikit-learn
+ einops
+ sentencepiece
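
A note on the pinned torch build: torch==2.7.1+cu118 carries a CUDA-specific local version tag, and a plain PyPI-only install (for example, a default Hugging Face Space) typically cannot resolve the +cu118 suffix without the matching PyTorch wheel index. A small post-install sanity check, assumed to run in the deployment environment, can confirm what actually got installed:

import torch

# Report the installed torch build and whether CUDA is usable,
# so the result can be compared against the pinned requirement.
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))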