sagar007 committed on
Commit 63d3bc6 · verified · 1 Parent(s): f87dcd8

Update app.py

Files changed (1)
  1. app.py +111 -213
app.py CHANGED
@@ -10,10 +10,10 @@ import numpy
 logging.basicConfig(level=logging.INFO)
 
 class LLaVAPhiModel:
-    def __init__(self, model_id="microsoft/phi-1_5"): # Updated to match config
         self.device = "cuda"
         self.model_id = model_id
-        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")
 
         # Initialize tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -21,37 +21,28 @@ class LLaVAPhiModel:
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         try:
-            # Use CLIPProcessor with the correct model name from config
             self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
             logging.info("Successfully loaded CLIP processor")
         except Exception as e:
             logging.error(f"Failed to load CLIP processor: {str(e)}")
             self.processor = None
 
-        # Increase history length to retain more context
         self.history = []
         self.model = None
         self.clip = None
-
-        # Default generation parameters - can be updated from config
-        self.temperature = 0.3
-        self.top_p = 0.92
-        self.top_k = 50
-        self.repetition_penalty = 1.2
-
-        # Set max length from config
-        self.max_length = 512 # Default value, will be updated from config
 
     @spaces.GPU
     def ensure_models_loaded(self):
         """Ensure models are loaded in GPU context"""
         if self.model is None:
-            # Use 4-bit quantization according to config
             from transformers import BitsAndBytesConfig
             quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True, # Changed to match config
-                bnb_4bit_compute_dtype=torch.bfloat16, # Changed to bfloat16 to match config's mixed_precision
-                bnb_4bit_use_double_quant=False
             )
 
             try:
@@ -63,156 +54,132 @@ class LLaVAPhiModel:
                     trust_remote_code=True
                 )
                 self.model.config.pad_token_id = self.tokenizer.eos_token_id
-                logging.info(f"Successfully loaded main model: {self.model_id}")
             except Exception as e:
                 logging.error(f"Failed to load main model: {str(e)}")
                 raise
 
         if self.clip is None:
             try:
-                # Load CLIP model from config
-                clip_model_name = "openai/clip-vit-base-patch32" # From config
-                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
-                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
             except Exception as e:
                 logging.error(f"Failed to load CLIP model: {str(e)}")
                 self.clip = None
 
-    def apply_lora_config(self, lora_params):
-        """Apply LoRA configuration to the model - to be called during training"""
-        from peft import LoraConfig, get_peft_model
-
-        lora_config = LoraConfig(
-            r=lora_params.get("r", 16),
-            lora_alpha=lora_params.get("lora_alpha", 32),
-            lora_dropout=lora_params.get("lora_dropout", 0.05),
-            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
-            bias="none",
-            task_type="CAUSAL_LM"
-        )
-
-        # Convert model to PEFT/LoRA model
-        self.model = get_peft_model(self.model, lora_config)
-        logging.info("Applied LoRA configuration to the model")
-        return self.model
-
-    @spaces.GPU(duration=120)
-    def generate_response(self, message, image=None):
         try:
             self.ensure_models_loaded()
 
-            # Prepare prompt based on whether we have an image
-            has_image = image is not None
 
-            # Process text input
-            if has_image:
-                # For image+text input
-                prompt = f"human: <image>\n{message}\ngpt:"
-
-                # Check if model has vision encoding capability
-                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
-                    logging.warning("Model doesn't have standard image encoding methods")
-                    has_image = False
-                    prompt = f"human: {message}\ngpt:"
-            else:
-                # For text-only input
-                prompt = f"human: {message}\ngpt:"
 
-            # Include previous conversation context
-            context = ""
-            for turn in self.history[-5:]: # Include 5 previous turns
-                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
-
-            full_prompt = context + prompt
-
-            # Tokenize the input text
-            inputs = self.tokenizer(
-                full_prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=self.max_length
-            )
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-            # LLaVA-Phi specific image handling
-            if has_image:
                 try:
-                    # Convert image to correct format
-                    if isinstance(image, str):
-                        image = Image.open(image)
-                    elif isinstance(image, numpy.ndarray):
-                        image = Image.fromarray(image)
-
-                    # Ensure image is in RGB mode
-                    if image.mode != 'RGB':
-                        image = image.convert('RGB')
-
-                    # Process the image with CLIP processor
                     image_inputs = self.processor(images=image, return_tensors="pt")
                     image_features = self.clip.get_image_features(
                         pixel_values=image_inputs.pixel_values.to(self.device)
                     )
-
-                    # Some LLaVA models have a prepare_inputs_for_generation method
-                    if hasattr(self.model, "prepare_inputs_for_generation"):
-                        logging.info("Using model's prepare_inputs_for_generation for image handling")
-
-                    # Generate with image context
-                    with torch.no_grad():
-                        outputs = self.model.generate(
-                            **inputs,
-                            max_new_tokens=256,
-                            min_length=20,
-                            temperature=self.temperature,
-                            do_sample=True,
-                            top_p=self.top_p,
-                            top_k=self.top_k,
-                            repetition_penalty=self.repetition_penalty,
-                            no_repeat_ngram_size=3,
-                            use_cache=True,
-                            pad_token_id=self.tokenizer.pad_token_id,
-                            eos_token_id=self.tokenizer.eos_token_id
-                        )
-
                 except Exception as e:
-                    logging.error(f"Error handling image: {str(e)}")
-                    # Fall back to text-only generation
-                    logging.info("Falling back to text-only generation")
-                    with torch.no_grad():
-                        outputs = self.model.generate(
-                            **inputs,
-                            max_new_tokens=256,
-                            min_length=20,
-                            temperature=self.temperature,
-                            do_sample=True,
-                            top_p=self.top_p,
-                            top_k=self.top_k,
-                            repetition_penalty=self.repetition_penalty,
-                            no_repeat_ngram_size=3,
-                            use_cache=True,
-                            pad_token_id=self.tokenizer.pad_token_id,
-                            eos_token_id=self.tokenizer.eos_token_id
-                        )
             else:
-                # Text-only generation
                 with torch.no_grad():
                     outputs = self.model.generate(
                         **inputs,
-                        max_new_tokens=200,
                         min_length=20,
-                        temperature=self.temperature,
                         do_sample=True,
-                        top_p=self.top_p,
-                        top_k=self.top_k,
-                        repetition_penalty=self.repetition_penalty,
                         no_repeat_ngram_size=4,
                         use_cache=True,
                         pad_token_id=self.tokenizer.pad_token_id,
                         eos_token_id=self.tokenizer.eos_token_id
                     )
 
-            # Decode and clean up the response
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
             # Clean up response
@@ -235,40 +202,14 @@ class LLaVAPhiModel:
         self.history = []
         return None
 
-    # Add new function to control generation parameters
-    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
-        """Update generation parameters to control hallucination tendency"""
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.repetition_penalty = repetition_penalty
-        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"
-
-    # New method to apply config file settings
-    def apply_config(self, config):
-        """Apply settings from config file"""
-        model_params = config.get("model_params", {})
-        self.model_id = model_params.get("model_name", self.model_id)
-        self.max_length = model_params.get("max_length", 512)
-
-        # Update generation parameters if needed
-        training_params = config.get("training_params", {})
-        # Could add specific updates based on training_params if needed
-
-        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"
-
-def create_demo(config=None):
     try:
-        # Initialize with config file settings
         model = LLaVAPhiModel()
 
-        if config:
-            model.apply_config(config)
-
        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
-                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )
@@ -288,66 +229,34 @@ def create_demo(config=None):
 
            image = gr.Image(type="pil", label="Upload Image (Optional)")
 
-            # Add generation parameter controls
-            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
-                gr.Markdown("Adjust these parameters to control hallucination tendency")
-                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
-                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
-                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
-                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
-                update_params = gr.Button("Update Parameters")
-
-            # Add debugging information box
-            debug_info = gr.Textbox(label="Debug Info", interactive=False)
-
-            # Add config information
-            if config:
-                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
-                gr.Markdown(f"**Current Configuration:** {config_info}")
-
            def respond(message, chat_history, image):
                if not message and image is None:
-                    return chat_history, ""
 
-                try:
-                    response = model.generate_response(message, image)
-                    chat_history.append((message, response))
-                    debug_msg = "Response generated successfully"
-                    return "", chat_history, debug_msg
-                except Exception as e:
-                    debug_msg = f"Error: {str(e)}"
-                    return message, chat_history, debug_msg
 
            def clear_chat():
                model.clear_history()
-                return None, None, "Chat history cleared"
-
-            def update_params_fn(temp, top_p, top_k, rep_penalty):
-                result = model.update_generation_params(temp, top_p, top_k, rep_penalty)
-                return f"Parameters updated: temp={temp}, top_p={top_p}, top_k={top_k}, rep_penalty={rep_penalty}"
 
            submit.click(
                respond,
                [msg, chatbot, image],
-                [msg, chatbot, debug_info],
            )
 
            clear.click(
                clear_chat,
                None,
-                [chatbot, image, debug_info],
            )
 
            msg.submit(
                respond,
                [msg, chatbot, image],
-                [msg, chatbot, debug_info],
-            )
-
-            update_params.click(
-                update_params_fn,
-                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
-                [debug_info]
            )
 
        return demo
@@ -356,18 +265,7 @@ def create_demo(config=None):
        raise
 
 if __name__ == "__main__":
-    # Load config file
-    import json
-
-    try:
-        with open("config.json", "r") as f:
-            config = json.load(f)
-        logging.info("Successfully loaded config file")
-    except Exception as e:
-        logging.error(f"Error loading config: {str(e)}")
-        config = None
-
-    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
 
@@ -10,10 +10,10 @@ import numpy
 logging.basicConfig(level=logging.INFO)
 
 class LLaVAPhiModel:
+    def __init__(self, model_id="sagar007/Lava_phi"):
         self.device = "cuda"
         self.model_id = model_id
+        logging.info("Initializing LLaVA-Phi model...")
 
         # Initialize tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 
@@ -21,37 +21,28 @@ class LLaVAPhiModel:
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         try:
+            # Use CLIPProcessor directly instead of AutoProcessor
             self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
             logging.info("Successfully loaded CLIP processor")
         except Exception as e:
             logging.error(f"Failed to load CLIP processor: {str(e)}")
             self.processor = None
 
         self.history = []
         self.model = None
         self.clip = None
 
     @spaces.GPU
     def ensure_models_loaded(self):
         """Ensure models are loaded in GPU context"""
         if self.model is None:
+            # Load main model with updated quantization config
             from transformers import BitsAndBytesConfig
             quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4"
             )
 
             try:
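The hunk above shows only the new BitsAndBytesConfig; the from_pretrained call it feeds into is truncated by the diff (only trust_remote_code=True is visible in the next hunk). For reference, a minimal sketch of how such a 4-bit NF4 config is typically wired into transformers model loading; the exact call in app.py may differ, and the variable names and device_map choice here are illustrative:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: mirrors the config added in this commit and the standard loading pattern.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "sagar007/Lava_phi",                      # model id used by __init__ in this commit
    quantization_config=quantization_config,
    device_map="auto",                        # assumption: the elided call in app.py may differ
    trust_remote_code=True,                   # matches the context line in the next hunk
)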
 
@@ -63,156 +54,132 @@ class LLaVAPhiModel:
                     trust_remote_code=True
                 )
                 self.model.config.pad_token_id = self.tokenizer.eos_token_id
+                logging.info("Successfully loaded main model")
             except Exception as e:
                 logging.error(f"Failed to load main model: {str(e)}")
                 raise
 
         if self.clip is None:
             try:
+                # Use CLIPModel directly instead of AutoModel
+                self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
+                logging.info("Successfully loaded CLIP model")
             except Exception as e:
                 logging.error(f"Failed to load CLIP model: {str(e)}")
                 self.clip = None
 
+    @spaces.GPU
+    def process_image(self, image):
+        """Process image through CLIP if available"""
         try:
             self.ensure_models_loaded()
 
+            if self.clip is None or self.processor is None:
+                logging.warning("CLIP model or processor not available")
+                return None
 
+            # Convert image to correct format
+            if isinstance(image, str):
+                image = Image.open(image)
+            elif isinstance(image, numpy.ndarray):
+                image = Image.fromarray(image)
 
+            # Ensure image is in RGB mode
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
 
+            with torch.no_grad():
                 try:
+                    # Process image with error handling
                     image_inputs = self.processor(images=image, return_tensors="pt")
                     image_features = self.clip.get_image_features(
                         pixel_values=image_inputs.pixel_values.to(self.device)
                     )
+                    logging.info("Successfully processed image through CLIP")
+                    return image_features
                 except Exception as e:
+                    logging.error(f"Error during image processing: {str(e)}")
+                    return None
+        except Exception as e:
+            logging.error(f"Error in process_image: {str(e)}")
+            return None
+
+    @spaces.GPU(duration=120)
+    def generate_response(self, message, image=None):
+        try:
+            self.ensure_models_loaded()
+
+            if image is not None:
+                image_features = self.process_image(image)
+                has_image = image_features is not None
+                if not has_image:
+                    message = "Note: Image processing is not available - continuing with text only.\n" + message
+
+                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
+                context = ""
+                for turn in self.history[-3:]:
+                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
+
+                full_prompt = context + prompt
+                inputs = self.tokenizer(
+                    full_prompt,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=512
+                )
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                if has_image:
+                    inputs["image_features"] = image_features
+
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        **inputs,
+                        max_new_tokens=256,
+                        min_length=20,
+                        temperature=0.7,
+                        do_sample=True,
+                        top_p=0.9,
+                        top_k=40,
+                        repetition_penalty=1.5,
+                        no_repeat_ngram_size=3,
+                        use_cache=True,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id
+                    )
             else:
+                prompt = f"human: {message}\ngpt:"
+                context = ""
+                for turn in self.history[-3:]:
+                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
+
+                full_prompt = context + prompt
+                inputs = self.tokenizer(
+                    full_prompt,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=512
+                )
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
                 with torch.no_grad():
                     outputs = self.model.generate(
                         **inputs,
+                        max_new_tokens=150,
                         min_length=20,
+                        temperature=0.6,
                         do_sample=True,
+                        top_p=0.85,
+                        top_k=30,
+                        repetition_penalty=1.8,
                         no_repeat_ngram_size=4,
                         use_cache=True,
                         pad_token_id=self.tokenizer.pad_token_id,
                         eos_token_id=self.tokenizer.eos_token_id
                     )
 
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
             # Clean up response
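For readers tracing the new generate_response, the history/prompt assembly can be reproduced in isolation. A small self-contained sketch, with invented sample turns, showing exactly what the f-strings above produce:

# Standalone illustration of the prompt construction used above (sample data is invented).
history = [
    ("What is in the image?", "A cat sitting on a sofa."),
    ("What colour is it?", "Grey and white."),
]
message = "Is it sleeping?"

context = ""
for turn in history[-3:]:  # only the last three turns are kept, as in the diff
    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"

full_prompt = context + f"human: {message}\ngpt:"
print(full_prompt)
# human: What is in the image?
# gpt: A cat sitting on a sofa.
# human: What colour is it?
# gpt: Grey and white.
# human: Is it sleeping?
# gpt: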
 
@@ -235,40 +202,14 @@ class LLaVAPhiModel:
         self.history = []
         return None
 
+def create_demo():
    try:
        model = LLaVAPhiModel()
 
        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
+                # LLaVA-Phi Demo (ZeroGPU)
                Chat with a vision-language model that can understand both text and images.
                """
            )
 
@@ -288,66 +229,34 @@ def create_demo(config=None):
 
            image = gr.Image(type="pil", label="Upload Image (Optional)")
 
            def respond(message, chat_history, image):
                if not message and image is None:
+                    return chat_history
 
+                response = model.generate_response(message, image)
+                chat_history.append((message, response))
+                return "", chat_history
 
            def clear_chat():
                model.clear_history()
+                return None, None
 
            submit.click(
                respond,
                [msg, chatbot, image],
+                [msg, chatbot],
            )
 
            clear.click(
                clear_chat,
                None,
+                [chatbot, image],
            )
 
            msg.submit(
                respond,
                [msg, chatbot, image],
+                [msg, chatbot],
            )
 
        return demo
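The event wiring above follows the standard Gradio Blocks pattern: a handler's return values are written back, in order, to the components listed as outputs, which is why respond returns an empty string (clearing the textbox) followed by the updated chat history. A stripped-down sketch of the same pattern with a placeholder handler; the component names mirror those in app.py, but the layout here is assumed:

import gradio as gr

# Minimal sketch of the wiring pattern only; the real handlers live in app.py.
def respond(message, chat_history, image):
    chat_history = chat_history + [(message, f"echo: {message}")]  # placeholder response
    return "", chat_history  # first value clears the textbox, second updates the chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    image = gr.Image(type="pil", label="Upload Image (Optional)")
    submit = gr.Button("Send")

    # Outputs [msg, chatbot] receive the two return values, as in the diff above.
    submit.click(respond, [msg, chatbot, image], [msg, chatbot])
    msg.submit(respond, [msg, chatbot, image], [msg, chatbot])

demo.launch()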
 
@@ -356,18 +265,7 @@ def create_demo(config=None):
        raise
 
 if __name__ == "__main__":
+    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,