Chamin09 committed
Commit 71aaa5d · verified · 1 Parent(s): a40ee8d

Update models/image_models.py

Files changed (1)
  1. models/image_models.py +284 -267
models/image_models.py CHANGED
@@ -1,267 +1,284 @@
- # models/image_models.py
- import logging
- import os
- import torch
- from typing import Dict, List, Optional, Tuple, Union, Any
- from PIL import Image
- import numpy as np
- from transformers import BlipProcessor, BlipForConditionalGeneration
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
-
- class ImageModelManager:
-     def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
-         """Initialize the ImageModelManager with optional utilities."""
-         self.logger = logging.getLogger(__name__)
-         self.token_manager = token_manager
-         self.cache_manager = cache_manager
-         self.metrics_calculator = metrics_calculator
-
-         # Model instances
-         self.lightweight_model = None
-         self.lightweight_processor = None
-         self.advanced_model = None
-         self.advanced_processor = None
-
-         # Model names
-         self.lightweight_model_name = "Salesforce/blip-image-captioning-base"
-         self.advanced_model_name = "Salesforce/blip2-opt-2.7b"
-
-         # Track initialization state
-         self.initialized = {
-             "lightweight": False,
-             "advanced": False
-         }
-
-         # Default complexity thresholds
-         self.complexity_thresholds = {
-             "entropy": 4.5,  # Higher entropy suggests more complex image
-             "edge_density": 0.15,  # Higher edge density suggests more details
-             "size": 500000  # Larger images may contain more information
-         }
-
-     def initialize_lightweight_model(self):
-         """Initialize the lightweight image captioning model."""
-         if self.initialized["lightweight"]:
-             return
-
-         try:
-             # Register with token manager if available
-             if self.token_manager:
-                 self.token_manager.register_model(
-                     self.lightweight_model_name, "image_captioning")
-
-             # Load model and processor
-             self.logger.info(f"Loading lightweight image model: {self.lightweight_model_name}")
-             self.lightweight_processor = BlipProcessor.from_pretrained(self.lightweight_model_name)
-             self.lightweight_model = BlipForConditionalGeneration.from_pretrained(
-                 self.lightweight_model_name, torch_dtype=torch.float32)
-
-             self.initialized["lightweight"] = True
-             self.logger.info("Lightweight image model initialized successfully")
-
-         except Exception as e:
-             self.logger.error(f"Failed to initialize lightweight image model: {e}")
-             raise
-
-     def initialize_advanced_model(self):
-         """Initialize the advanced image captioning model."""
-         if self.initialized["advanced"]:
-             return
-
-         try:
-             # Register with token manager if available
-             if self.token_manager:
-                 self.token_manager.register_model(
-                     self.advanced_model_name, "image_captioning")
-
-             # Load model and processor
-             self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
-             self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
-             self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
-                 self.advanced_model_name, torch_dtype=torch.float16)
-
-             self.initialized["advanced"] = True
-             self.logger.info("Advanced image model initialized successfully")
-
-         except Exception as e:
-             self.logger.error(f"Failed to initialize advanced image model: {e}")
-             raise
-
-     def determine_image_complexity(self, image: Image.Image) -> Dict[str, float]:
-         """
-         Determine the complexity of an image to guide model selection.
-         Returns complexity metrics.
-         """
-         # Convert to numpy array
-         img_array = np.array(image.convert("L"))  # Convert to grayscale for analysis
-
-         # Calculate image entropy (measure of randomness/information)
-         histogram = np.histogram(img_array, bins=256, range=(0, 256))[0]
-         histogram = histogram / histogram.sum()
-         non_zero = histogram > 0
-         entropy = -np.sum(histogram[non_zero] * np.log2(histogram[non_zero]))
-
-         # Calculate edge density using simple gradient method
-         gradient_x = np.abs(np.diff(img_array, axis=1, prepend=0))
-         gradient_y = np.abs(np.diff(img_array, axis=0, prepend=0))
-         gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
-         edge_density = np.mean(gradient_magnitude > 30)  # Threshold for edge detection
-
-         # Get image size in pixels
-         size = image.width * image.height
-
-         return {
-             "entropy": float(entropy),
-             "edge_density": float(edge_density),
-             "size": size
-         }
-
-     def select_captioning_model(self, image: Image.Image) -> str:
-         """
-         Select the appropriate captioning model based on image complexity.
-         Returns model type ("lightweight" or "advanced").
-         """
-         # Get complexity metrics
-         complexity = self.determine_image_complexity(image)
-
-         # Decision logic for model selection
-         use_advanced = (
-             complexity["entropy"] > self.complexity_thresholds["entropy"] or
-             complexity["edge_density"] > self.complexity_thresholds["edge_density"] or
-             complexity["size"] > self.complexity_thresholds["size"]
-         )
-
-         # Log selection decision
-         model_type = "advanced" if use_advanced else "lightweight"
-         self.logger.info(f"Selected {model_type} model for image captioning (complexity: {complexity})")
-
-         # If metrics calculator is available, log model selection
-         if use_advanced and self.metrics_calculator:
-             # Estimate energy saved if we had used the advanced model
-             # This is a negative number since we're using more energy
-             energy_diff = -0.01  # Approximate difference in watt-hours
-             self.metrics_calculator.log_model_downgrade(
-                 self.advanced_model_name, self.lightweight_model_name, energy_diff)
-
-         return model_type
-
-     def generate_image_caption(self, image: Union[str, Image.Image],
-                                agent_name: str = "image_processing") -> Dict[str, Any]:
-         """
-         Generate caption for an image, selecting appropriate model based on complexity.
-         Returns caption and metadata.
-         """
-         # Handle string input (file path)
-         if isinstance(image, str):
-             if os.path.exists(image):
-                 image = Image.open(image).convert('RGB')
-             else:
-                 raise ValueError(f"Image file not found: {image}")
-
-         # Ensure image is PIL Image
-         if not isinstance(image, Image.Image):
-             raise TypeError("Image must be a PIL Image or a valid file path")
-
-         # Check cache if available
-         image_hash = str(hash(image.tobytes()))
-         if self.cache_manager:
-             cache_hit, cached_result = self.cache_manager.get(
-                 image_hash, namespace="image_captions")
-
-             if cache_hit:
-                 # Update metrics if available
-                 if self.metrics_calculator:
-                     self.metrics_calculator.update_cache_metrics(1, 0, 0.01)  # Estimated energy saving
-                 return cached_result
-
-         # Select model based on image complexity
-         model_type = self.select_captioning_model(image)
-
-         # Initialize selected model if needed
-         if model_type == "advanced":
-             if not self.initialized["advanced"]:
-                 self.initialize_advanced_model()
-
-             processor = self.advanced_processor
-             model = self.advanced_model
-             model_name = self.advanced_model_name
-         else:
-             if not self.initialized["lightweight"]:
-                 self.initialize_lightweight_model()
-
-             processor = self.lightweight_processor
-             model = self.lightweight_model
-             model_name = self.lightweight_model_name
-
-         # Process image
-         inputs = processor(image, return_tensors="pt")
-
-         # Request token budget if available
-         if self.token_manager:
-             # Estimate token usage (approximate)
-             estimated_tokens = 50  # Base tokens for generation
-             approved, reason = self.token_manager.request_tokens(
-                 agent_name, "image_captioning", "", model_name)
-
-             if not approved:
-                 self.logger.warning(f"Token budget exceeded: {reason}")
-                 return {"caption": "Token budget exceeded", "error": reason}
-
-         # Generate caption
-         with torch.no_grad():
-             if model_type == "advanced":
-                 generated_ids = model.generate(
-                     pixel_values=inputs.pixel_values,
-                     max_length=30,
-                     num_beams=5
-                 )
-                 caption = processor.decode(generated_ids[0], skip_special_tokens=True)
-             else:
-                 outputs = model.generate(**inputs, max_length=30, num_beams=5)
-                 caption = processor.decode(outputs[0], skip_special_tokens=True)
-
-         # Prepare result
-         result = {
-             "caption": caption,
-             "model_used": model_type,
-             "complexity": self.determine_image_complexity(image),
-             "confidence": 0.9 if model_type == "advanced" else 0.7  # Estimated confidence
-         }
-
-         # Log token usage if available
-         if self.token_manager:
-             # Approximate token count based on output length
-             token_count = len(caption.split()) + 20  # Base tokens + output
-             self.token_manager.log_usage(
-                 agent_name, "image_captioning", token_count, model_name)
-
-             # Log energy usage if metrics calculator is available
-             if self.metrics_calculator:
-                 energy_usage = self.token_manager.calculate_energy_usage(
-                     token_count, model_name)
-                 self.metrics_calculator.log_energy_usage(
-                     energy_usage, model_name, agent_name, "image_captioning")
-
-         # Store in cache if available
-         if self.cache_manager:
-             self.cache_manager.put(image_hash, result, namespace="image_captions")
-
-         return result
-
-     def match_images_to_topic(self, topic: str, image_captions: List[Dict[str, Any]],
-                               text_model_manager=None) -> List[float]:
-         """
-         Match image captions to the user's topic using semantic similarity.
-         Returns relevance scores for each image.
-         """
-         if not text_model_manager:
-             self.logger.warning("No text model manager provided for semantic matching")
-             return [0.5] * len(image_captions)  # Default mid-range relevance
-
-         # Extract captions
-         captions = [item["caption"] for item in image_captions]
-
-         # Use text model to compute similarity
-         similarities = text_model_manager.compute_similarity(topic, captions)
-
-         return similarities
+ # models/image_models.py
+ import logging
+ import os
+ import torch
+ from typing import Dict, List, Optional, Tuple, Union, Any
+ from PIL import Image
+ import numpy as np
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+ class ImageModelManager:
+     def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
+         """Initialize the ImageModelManager with optional utilities."""
+         self.logger = logging.getLogger(__name__)
+         self.token_manager = token_manager
+         self.cache_manager = cache_manager
+         self.metrics_calculator = metrics_calculator
+
+         # Model instances
+         self.lightweight_model = None
+         self.lightweight_processor = None
+         self.advanced_model = None
+         self.advanced_processor = None
+
+         # Model names
+         self.lightweight_model_name = "Salesforce/blip-image-captioning-base"
+         self.advanced_model_name = "Salesforce/blip2-opt-2.7b"
+
+         # Track initialization state
+         self.initialized = {
+             "lightweight": False,
+             "advanced": False
+         }
+
+         # Default complexity thresholds
+         self.complexity_thresholds = {
+             "entropy": 4.5,  # Higher entropy suggests more complex image
+             "edge_density": 0.15,  # Higher edge density suggests more details
+             "size": 500000  # Larger images may contain more information
+         }
+
+     def initialize_lightweight_model(self):
+         """Initialize the lightweight image captioning model."""
+         if self.initialized["lightweight"]:
+             return
+
+         try:
+             # Register with token manager if available
+             if self.token_manager:
+                 self.token_manager.register_model(
+                     self.lightweight_model_name, "image_captioning")
+
+             # Load model and processor
+             self.logger.info(f"Loading lightweight image model: {self.lightweight_model_name}")
+             self.lightweight_processor = BlipProcessor.from_pretrained(self.lightweight_model_name)
+             self.lightweight_model = BlipForConditionalGeneration.from_pretrained(
+                 self.lightweight_model_name, torch_dtype=torch.float32)
+
+             self.initialized["lightweight"] = True
+             self.logger.info("Lightweight image model initialized successfully")
+
+         except Exception as e:
+             self.logger.error(f"Failed to initialize lightweight image model: {e}")
+             raise
+
+     def initialize_advanced_model(self):
+         """Initialize the advanced image captioning model."""
+         if self.initialized["advanced"]:
+             return
+
+         try:
+             # Register with token manager if available
+             if self.token_manager:
+                 self.token_manager.register_model(
+                     self.advanced_model_name, "image_captioning")
+
+             # Load model and processor
+             self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
+             self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
+             self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
+                 self.advanced_model_name, torch_dtype=torch.float16)
+
+             self.initialized["advanced"] = True
+             self.logger.info("Advanced image model initialized successfully")
+
+         except Exception as e:
+             self.logger.error(f"Failed to initialize advanced image model: {e}")
+             raise
+
+     def determine_image_complexity(self, image: Image.Image) -> Dict[str, float]:
+         """
+         Determine the complexity of an image to guide model selection.
+         Returns complexity metrics.
+         """
+         # Convert to numpy array
+         img_array = np.array(image.convert("L"))  # Convert to grayscale for analysis
+
+         # Calculate image entropy (measure of randomness/information)
+         histogram = np.histogram(img_array, bins=256, range=(0, 256))[0]
+         histogram = histogram / histogram.sum()
+         non_zero = histogram > 0
+         entropy = -np.sum(histogram[non_zero] * np.log2(histogram[non_zero]))
+
+         # Calculate edge density using simple gradient method
+         gradient_x = np.abs(np.diff(img_array, axis=1, prepend=0))
+         gradient_y = np.abs(np.diff(img_array, axis=0, prepend=0))
+         gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
+         edge_density = np.mean(gradient_magnitude > 30)  # Threshold for edge detection
+
+         # Get image size in pixels
+         size = image.width * image.height
+
+         return {
+             "entropy": float(entropy),
+             "edge_density": float(edge_density),
+             "size": size
+         }
+
+     def select_captioning_model(self, image: Image.Image) -> str:
+         """
+         Select the appropriate captioning model based on image complexity.
+         Returns model type ("lightweight" or "advanced").
+         """
+         # Get complexity metrics
+         complexity = self.determine_image_complexity(image)
+
+         # Decision logic for model selection
+         use_advanced = (
+             complexity["entropy"] > self.complexity_thresholds["entropy"] or
+             complexity["edge_density"] > self.complexity_thresholds["edge_density"] or
+             complexity["size"] > self.complexity_thresholds["size"]
+         )
+
+         # Log selection decision
+         model_type = "advanced" if use_advanced else "lightweight"
+         self.logger.info(f"Selected {model_type} model for image captioning (complexity: {complexity})")
+
+         # If metrics calculator is available, log model selection
+         if use_advanced and self.metrics_calculator:
+             # Estimate energy saved if we had used the advanced model
+             # This is a negative number since we're using more energy
+             energy_diff = -0.01  # Approximate difference in watt-hours
+             self.metrics_calculator.log_model_downgrade(
+                 self.advanced_model_name, self.lightweight_model_name, energy_diff)
+
+         return model_type
+
+     def generate_image_caption(self, image: Union[str, Image.Image],
+                                agent_name: str = "image_processing") -> Dict[str, Any]:
+         """
+         Generate caption for an image, selecting appropriate model based on complexity.
+         Returns caption and metadata.
+         """
+         # Handle string input (file path)
+         if isinstance(image, str):
+             if os.path.exists(image):
+                 image = Image.open(image).convert('RGB')
+             else:
+                 raise ValueError(f"Image file not found: {image}")
+
+         # Ensure image is PIL Image
+         if not isinstance(image, Image.Image):
+             raise TypeError("Image must be a PIL Image or a valid file path")
+
+         # Check cache if available
+         image_hash = str(hash(image.tobytes()))
+         if self.cache_manager:
+             cache_hit, cached_result = self.cache_manager.get(
+                 image_hash, namespace="image_captions")
+
+             if cache_hit:
+                 # Update metrics if available
+                 if self.metrics_calculator:
+                     self.metrics_calculator.update_cache_metrics(1, 0, 0.01)  # Estimated energy saving
+                 return cached_result
+
+         # Select model based on image complexity
+         model_type = self.select_captioning_model(image)
+
+         # Initialize selected model if needed
+         if model_type == "advanced":
+             if not self.initialized["advanced"]:
+                 self.initialize_advanced_model()
+
+             processor = self.advanced_processor
+             model = self.advanced_model
+             model_name = self.advanced_model_name
+         else:
+             if not self.initialized["lightweight"]:
+                 self.initialize_lightweight_model()
+
+             processor = self.lightweight_processor
+             model = self.lightweight_model
+             model_name = self.lightweight_model_name
+
+         # Process image
+         inputs = processor(image, return_tensors="pt")
+
+         # Request token budget if available
+         if self.token_manager:
+             # Estimate token usage (approximate)
+             estimated_tokens = 50  # Base tokens for generation
+             approved, reason = self.token_manager.request_tokens(
+                 agent_name, "image_captioning", "", model_name)
+
+             if not approved:
+                 self.logger.warning(f"Token budget exceeded: {reason}")
+                 return {"caption": "Token budget exceeded", "error": reason}
+
+
+         # Generate caption
+         with torch.no_grad():
+             if model_type == "advanced":
+                 generated_ids = model.generate(
+                     pixel_values=inputs.pixel_values,
+                     max_new_tokens=50,  # Using max_new_tokens instead of max_length
+                     num_beams=5
+                 )
+                 caption = processor.decode(generated_ids[0], skip_special_tokens=True)
+             else:
+                 outputs = model.generate(
+                     **inputs,
+                     max_new_tokens=50,  # Using max_new_tokens instead of max_length
+                     num_beams=5
+                 )
+                 caption = processor.decode(outputs[0], skip_special_tokens=True)
+         # # Generate caption
+         # with torch.no_grad():
+         #     if model_type == "advanced":
+         #         generated_ids = model.generate(
+         #             pixel_values=inputs.pixel_values,
+         #             max_length=30,
+         #             num_beams=5
+         #         )
+         #         caption = processor.decode(generated_ids[0], skip_special_tokens=True)
+         #     else:
+         #         outputs = model.generate(**inputs, max_length=30, num_beams=5)
+         #         caption = processor.decode(outputs[0], skip_special_tokens=True)
+
+         # Prepare result
+         result = {
+             "caption": caption,
+             "model_used": model_type,
+             "complexity": self.determine_image_complexity(image),
+             "confidence": 0.9 if model_type == "advanced" else 0.7  # Estimated confidence
+         }
+
+         # Log token usage if available
+         if self.token_manager:
+             # Approximate token count based on output length
+             token_count = len(caption.split()) + 20  # Base tokens + output
+             self.token_manager.log_usage(
+                 agent_name, "image_captioning", token_count, model_name)
+
+             # Log energy usage if metrics calculator is available
+             if self.metrics_calculator:
+                 energy_usage = self.token_manager.calculate_energy_usage(
+                     token_count, model_name)
+                 self.metrics_calculator.log_energy_usage(
+                     energy_usage, model_name, agent_name, "image_captioning")
+
+         # Store in cache if available
+         if self.cache_manager:
+             self.cache_manager.put(image_hash, result, namespace="image_captions")
+
+         return result
+
+     def match_images_to_topic(self, topic: str, image_captions: List[Dict[str, Any]],
+                               text_model_manager=None) -> List[float]:
+         """
+         Match image captions to the user's topic using semantic similarity.
+         Returns relevance scores for each image.
+         """
+         if not text_model_manager:
+             self.logger.warning("No text model manager provided for semantic matching")
+             return [0.5] * len(image_captions)  # Default mid-range relevance
+
+         # Extract captions
+         captions = [item["caption"] for item in image_captions]
+
+         # Use text model to compute similarity
+         similarities = text_model_manager.compute_similarity(topic, captions)
+
+         return similarities
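
For reference, a minimal usage sketch of the class touched by this commit, assuming the module is importable as models.image_models; the image path is illustrative and the optional token/cache/metrics managers are left at their None defaults:

# Hypothetical usage sketch of ImageModelManager after this change.
# "samples/photo.jpg" is an illustrative path, not a file in this repo.
from models.image_models import ImageModelManager

manager = ImageModelManager()  # token_manager/cache_manager/metrics_calculator omitted
result = manager.generate_image_caption("samples/photo.jpg")
print(result["caption"], result["model_used"], result["confidence"])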