atharwaah1work commited on
Commit
1a35ebb
Β·
verified Β·
1 Parent(s): 15425df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -142
app.py CHANGED
@@ -1,60 +1,41 @@
1
- # app.py
2
  import torch
3
- from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
4
  import gradio as gr
5
  from PIL import Image
6
  import re
7
  from typing import List, Tuple
8
- import os
9
- import logging
10
-
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO)
13
- logger = logging.getLogger(__name__)
14
 
 
 
 
 
 
 
 
15
 
16
  class RiverPollutionAnalyzer:
17
  def __init__(self):
18
- # Check if CUDA is available
19
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
21
- logger.info(f"Using device: {self.device}, dtype: {torch_dtype}")
22
-
23
  try:
24
- # Load processor
25
- logger.info("Loading processor...")
26
  self.processor = InstructBlipProcessor.from_pretrained(
27
- "Salesforce/instructblip-vicuna-7b"
 
28
  )
29
-
30
- # Load model with appropriate settings
31
- logger.info("Loading model...")
32
  self.model = InstructBlipForConditionalGeneration.from_pretrained(
33
- "Salesforce/instructblip-vicuna-7b",
 
34
  device_map="auto",
35
- torch_dtype=torch_dtype,
36
- load_in_8bit=True if self.device == "cuda" else False,
37
- offload_folder="./offload",
38
  )
39
- logger.info("Model loaded successfully")
40
-
41
  except Exception as e:
42
- logger.error(f"Error loading model: {str(e)}")
43
- raise RuntimeError(f"Failed to initialize model: {str(e)}")
44
 
45
  self.pollutants = [
46
- "plastic waste",
47
- "chemical foam",
48
- "industrial discharge",
49
- "sewage water",
50
- "oil spill",
51
- "organic debris",
52
- "construction waste",
53
- "medical waste",
54
- "floating trash",
55
- "algal bloom",
56
- "toxic sludge",
57
- "agricultural runoff",
58
  ]
59
 
60
  self.severity_descriptions = {
@@ -67,16 +48,15 @@ class RiverPollutionAnalyzer:
67
  7: "Very severe pollution - Major ecosystem impact",
68
  8: "Extreme pollution - Dangerous levels",
69
  9: "Critical pollution - Immediate action needed",
70
- 10: "Disaster level - Ecological catastrophe",
71
  }
72
 
73
  def analyze_image(self, image):
74
  """Analyze river pollution with robust parsing"""
75
- try:
76
- if not isinstance(image, Image.Image):
77
- image = Image.fromarray(image)
78
 
79
- prompt = """Analyze this river pollution scene and provide:
80
  1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
81
  2. Estimate pollution severity from 1-10
82
 
@@ -84,9 +64,12 @@ Respond EXACTLY in this format:
84
  Pollutants: [comma separated list]
85
  Severity: [number]"""
86
 
87
- inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(
88
- self.device
89
- )
 
 
 
90
 
91
  with torch.no_grad():
92
  outputs = self.model.generate(
@@ -94,38 +77,39 @@ Severity: [number]"""
94
  max_new_tokens=200,
95
  temperature=0.5,
96
  top_p=0.85,
97
- do_sample=True,
98
  )
99
 
100
  analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
101
  pollutants, severity = self._parse_response(analysis)
102
  return self._format_analysis(pollutants, severity)
103
-
104
  except Exception as e:
105
- logger.error(f"Error analyzing image: {str(e)}")
106
- return "⚠️ Error analyzing image. Please try again or check the logs for details."
107
 
 
108
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
109
- """Robust parsing of model response"""
110
  pollutants = []
111
  severity = 3
112
 
113
  # Extract pollutants
114
  pollutant_match = re.search(
115
- r"(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)",
116
- analysis,
117
  )
118
 
119
  if pollutant_match:
120
  pollutants_str = pollutant_match.group(2).strip()
121
  pollutants = [
122
  p.strip().lower()
123
- for p in re.split(r"[,;]|\band\b", pollutants_str)
124
  if p.strip().lower() in self.pollutants
125
  ]
126
 
127
  # Extract severity
128
- severity_match = re.search(r"(?i)(severity|level)[:\s]*(\d{1,2})", analysis)
 
 
 
129
 
130
  if severity_match:
131
  try:
@@ -138,72 +122,46 @@ Severity: [number]"""
138
  return pollutants, severity
139
 
140
  def _calculate_severity(self, pollutants: List[str]) -> int:
141
- """Weighted severity calculation"""
142
  if not pollutants:
143
  return 1
144
 
145
  weights = {
146
- "medical waste": 3,
147
- "toxic sludge": 3,
148
- "oil spill": 2.5,
149
- "chemical foam": 2,
150
- "industrial discharge": 2,
151
- "sewage water": 2,
152
- "plastic waste": 1.5,
153
- "construction waste": 1.5,
154
- "algal bloom": 1.5,
155
- "agricultural runoff": 1.5,
156
- "floating trash": 1,
157
- "organic debris": 1,
158
  }
159
 
160
  avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
161
  return min(10, max(1, round(avg_weight * 3)))
162
 
163
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
164
- """Generate formatted report"""
165
  severity_bar = f"""πŸ“Š Severity: {severity}/10
166
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
167
- {self.severity_descriptions.get(severity, "")}"""
168
 
169
- pollutants_list = (
170
- "\nπŸ” No pollutants detected"
171
- if not pollutants
172
- else "\n".join(
173
- f"{i}. {p.capitalize()}" for i, p in enumerate(pollutants[:5], 1)
174
- )
175
- )
176
 
177
  return f"""🌊 River Pollution Analysis 🌊
178
  {pollutants_list}
179
  {severity_bar}"""
180
 
181
  def analyze_chat(self, message: str) -> str:
182
- """Handle chat questions about pollution"""
183
- try:
184
- message = message.lower().strip()
185
- if any(word in message for word in ["hello", "hi", "hey"]):
186
- return "Hello! I'm a river pollution analyzer. Ask me about pollution types or upload an image for analysis."
187
- elif "pollution" in message:
188
- return "Common river pollutants include: plastic waste, chemical foam, industrial discharge, sewage water, and oil spills."
189
- elif "severity" in message:
190
- return "Severity is rated 1-10 (1=minimal, 10=disaster). It considers pollutant type and quantity."
191
- elif "help" in message:
192
- return "I can: 1) Analyze river pollution in images 2) Answer pollution questions. Try uploading an image or asking about pollution types."
193
- else:
194
- return "I specialize in river pollution analysis. Try asking about pollution types or upload an image for analysis."
195
- except Exception as e:
196
- logger.error(f"Error in chat: {str(e)}")
197
- return "Sorry, I encountered an error. Please try again."
198
-
199
 
200
- # Initialize analyzer
201
  try:
202
  analyzer = RiverPollutionAnalyzer()
203
- logger.info("Analyzer initialized successfully")
204
  except Exception as e:
205
- logger.error(f"Failed to initialize analyzer: {str(e)}")
206
- raise
207
 
208
  css = """
209
  .header {
@@ -213,16 +171,13 @@ css = """
213
  border-radius: 10px;
214
  margin-bottom: 20px;
215
  }
216
-
217
  .side-by-side {
218
  display: flex;
219
  gap: 20px;
220
  }
221
-
222
  .left-panel, .right-panel {
223
  flex: 1;
224
  }
225
-
226
  .analysis-box {
227
  padding: 20px;
228
  background: #f8f9fa;
@@ -230,97 +185,72 @@ css = """
230
  margin-top: 20px;
231
  border: 1px solid #dee2e6;
232
  }
233
-
234
  .chat-container {
235
  background: #f8f9fa;
236
  padding: 20px;
237
  border-radius: 10px;
238
  height: 100%;
239
  }
240
-
241
- .pollution-icon {
242
- font-size: 24px;
243
- margin-right: 10px;
244
- }
245
-
246
- .severity-bar {
247
- font-family: monospace;
248
- font-size: 16px;
249
- }
250
-
251
- .error-message {
252
- color: #dc3545;
253
- font-weight: bold;
254
- }
255
  """
256
 
257
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
258
  with gr.Column(elem_classes="header"):
259
  gr.Markdown("# 🌍 River Pollution Analyzer")
260
- gr.Markdown("### AI-powered water pollution detection")
261
 
262
  with gr.Row(elem_classes="side-by-side"):
263
- # Left Panel
264
  with gr.Column(elem_classes="left-panel"):
265
  with gr.Group():
266
- image_input = gr.Image(
267
- type="pil", label="Upload River Image", height=300
268
- )
269
  analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
270
-
271
  with gr.Group(elem_classes="analysis-box"):
272
  gr.Markdown("### πŸ“Š Analysis report")
273
  analysis_output = gr.Markdown()
274
 
275
- # Right Panel
276
  with gr.Column(elem_classes="right-panel"):
277
  with gr.Group(elem_classes="chat-container"):
278
- chatbot = gr.Chatbot(
279
- label="Pollution Analysis Q&A", height=400, bubble_full_width=False
280
- )
281
  with gr.Row():
282
  chat_input = gr.Textbox(
283
  placeholder="Ask about pollution sources...",
284
  label="Your Question",
285
  container=False,
286
- scale=5,
287
  )
288
  chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
289
  clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
290
 
291
  analyze_btn.click(
292
- analyzer.analyze_image, inputs=image_input, outputs=analysis_output
 
 
293
  )
294
 
295
  chat_input.submit(
296
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
297
  inputs=[chat_input, chatbot],
298
- outputs=[chat_input, chatbot],
299
  )
300
 
301
  chat_btn.click(
302
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
303
  inputs=[chat_input, chatbot],
304
- outputs=[chat_input, chatbot],
305
  )
306
 
307
  clear_btn.click(lambda: None, outputs=[chatbot])
308
 
309
  gr.Examples(
310
  examples=[
311
- [
312
- "https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river1.jpg"
313
- ],
314
- [
315
- "https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river2.jpg"
316
- ],
317
  ],
318
  inputs=image_input,
319
  outputs=analysis_output,
320
- fn=analyzer.analyze_image,
321
  cache_examples=True,
322
- label="Try example images:",
323
  )
324
 
325
- if __name__ == "__main__":
326
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import torch
2
+ from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig
3
  import gradio as gr
4
  from PIL import Image
5
  import re
6
  from typing import List, Tuple
 
 
 
 
 
 
7
 
8
+ # Configuration for 4-bit quantization
9
+ quant_config = BitsAndBytesConfig(
10
+ load_in_4bit=True,
11
+ bnb_4bit_compute_dtype=torch.float16,
12
+ bnb_4bit_quant_type="nf4",
13
+ bnb_4bit_use_double_quant=True
14
+ )
15
 
16
  class RiverPollutionAnalyzer:
17
  def __init__(self):
 
 
 
 
 
18
  try:
19
+ # Initialize InstructBLIP-FLAN-T5-XL with 4-bit quantization
 
20
  self.processor = InstructBlipProcessor.from_pretrained(
21
+ "Salesforce/instructblip-flan-t5-xl",
22
+ cache_dir="model_cache"
23
  )
 
 
 
24
  self.model = InstructBlipForConditionalGeneration.from_pretrained(
25
+ "Salesforce/instructblip-flan-t5-xl",
26
+ quantization_config=quant_config,
27
  device_map="auto",
28
+ torch_dtype=torch.float16,
29
+ cache_dir="model_cache"
 
30
  )
 
 
31
  except Exception as e:
32
+ raise RuntimeError(f"Model loading failed: {str(e)}")
 
33
 
34
  self.pollutants = [
35
+ "plastic waste", "chemical foam", "industrial discharge",
36
+ "sewage water", "oil spill", "organic debris",
37
+ "construction waste", "medical waste", "floating trash",
38
+ "algal bloom", "toxic sludge", "agricultural runoff"
 
 
 
 
 
 
 
 
39
  ]
40
 
41
  self.severity_descriptions = {
 
48
  7: "Very severe pollution - Major ecosystem impact",
49
  8: "Extreme pollution - Dangerous levels",
50
  9: "Critical pollution - Immediate action needed",
51
+ 10: "Disaster level - Ecological catastrophe"
52
  }
53
 
54
  def analyze_image(self, image):
55
  """Analyze river pollution with robust parsing"""
56
+ if not isinstance(image, Image.Image):
57
+ image = Image.fromarray(image)
 
58
 
59
+ prompt = """Analyze this river pollution scene and provide:
60
  1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
61
  2. Estimate pollution severity from 1-10
62
 
 
64
  Pollutants: [comma separated list]
65
  Severity: [number]"""
66
 
67
+ try:
68
+ inputs = self.processor(
69
+ images=image,
70
+ text=prompt,
71
+ return_tensors="pt"
72
+ ).to(self.model.device)
73
 
74
  with torch.no_grad():
75
  outputs = self.model.generate(
 
77
  max_new_tokens=200,
78
  temperature=0.5,
79
  top_p=0.85,
80
+ do_sample=True
81
  )
82
 
83
  analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
84
  pollutants, severity = self._parse_response(analysis)
85
  return self._format_analysis(pollutants, severity)
 
86
  except Exception as e:
87
+ return f"⚠️ Analysis failed: {str(e)}"
 
88
 
89
+ # [Keep your existing parsing/formatting methods]
90
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
 
91
  pollutants = []
92
  severity = 3
93
 
94
  # Extract pollutants
95
  pollutant_match = re.search(
96
+ r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
97
+ analysis
98
  )
99
 
100
  if pollutant_match:
101
  pollutants_str = pollutant_match.group(2).strip()
102
  pollutants = [
103
  p.strip().lower()
104
+ for p in re.split(r'[,;]|\band\b', pollutants_str)
105
  if p.strip().lower() in self.pollutants
106
  ]
107
 
108
  # Extract severity
109
+ severity_match = re.search(
110
+ r'(?i)(severity|level)[:\s]*(\d{1,2})',
111
+ analysis
112
+ )
113
 
114
  if severity_match:
115
  try:
 
122
  return pollutants, severity
123
 
124
  def _calculate_severity(self, pollutants: List[str]) -> int:
 
125
  if not pollutants:
126
  return 1
127
 
128
  weights = {
129
+ "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
130
+ "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
131
+ "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
132
+ "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1
 
 
 
 
 
 
 
 
133
  }
134
 
135
  avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
136
  return min(10, max(1, round(avg_weight * 3)))
137
 
138
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
 
139
  severity_bar = f"""πŸ“Š Severity: {severity}/10
140
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
141
+ {self.severity_descriptions.get(severity, '')}"""
142
 
143
+ pollutants_list = "\nπŸ” No pollutants detected" if not pollutants else "\n".join(
144
+ f"{i}. {p.capitalize()}" for i, p in enumerate(pollutants[:5], 1))
 
 
 
 
 
145
 
146
  return f"""🌊 River Pollution Analysis 🌊
147
  {pollutants_list}
148
  {severity_bar}"""
149
 
150
  def analyze_chat(self, message: str) -> str:
151
+ if any(word in message.lower() for word in ["hello", "hi", "hey"]):
152
+ return "Hello! I'm a river pollution analyzer. Ask me about pollution types or upload an image for analysis."
153
+ elif "pollution" in message.lower():
154
+ return "Common river pollutants include: plastic waste, chemical foam, industrial discharge, sewage water, and oil spills."
155
+ else:
156
+ return "I can answer questions about river pollution. Try asking about pollution types or upload an image for analysis."
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ # Initialize with error handling
159
  try:
160
  analyzer = RiverPollutionAnalyzer()
161
+ model_status = "βœ… Model loaded successfully"
162
  except Exception as e:
163
+ analyzer = None
164
+ model_status = f"❌ Model loading failed: {str(e)}"
165
 
166
  css = """
167
  .header {
 
171
  border-radius: 10px;
172
  margin-bottom: 20px;
173
  }
 
174
  .side-by-side {
175
  display: flex;
176
  gap: 20px;
177
  }
 
178
  .left-panel, .right-panel {
179
  flex: 1;
180
  }
 
181
  .analysis-box {
182
  padding: 20px;
183
  background: #f8f9fa;
 
185
  margin-top: 20px;
186
  border: 1px solid #dee2e6;
187
  }
 
188
  .chat-container {
189
  background: #f8f9fa;
190
  padding: 20px;
191
  border-radius: 10px;
192
  height: 100%;
193
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  """
195
 
196
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
197
  with gr.Column(elem_classes="header"):
198
  gr.Markdown("# 🌍 River Pollution Analyzer")
199
+ gr.Markdown(f"### {model_status}")
200
 
201
  with gr.Row(elem_classes="side-by-side"):
 
202
  with gr.Column(elem_classes="left-panel"):
203
  with gr.Group():
204
+ image_input = gr.Image(type="pil", label="Upload River Image", height=300)
 
 
205
  analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
206
+
207
  with gr.Group(elem_classes="analysis-box"):
208
  gr.Markdown("### πŸ“Š Analysis report")
209
  analysis_output = gr.Markdown()
210
 
 
211
  with gr.Column(elem_classes="right-panel"):
212
  with gr.Group(elem_classes="chat-container"):
213
+ chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
 
 
214
  with gr.Row():
215
  chat_input = gr.Textbox(
216
  placeholder="Ask about pollution sources...",
217
  label="Your Question",
218
  container=False,
219
+ scale=5
220
  )
221
  chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
222
  clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
223
 
224
  analyze_btn.click(
225
+ analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
226
+ inputs=image_input,
227
+ outputs=analysis_output
228
  )
229
 
230
  chat_input.submit(
231
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
232
  inputs=[chat_input, chatbot],
233
+ outputs=[chat_input, chatbot]
234
  )
235
 
236
  chat_btn.click(
237
  lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
238
  inputs=[chat_input, chatbot],
239
+ outputs=[chat_input, chatbot]
240
  )
241
 
242
  clear_btn.click(lambda: None, outputs=[chatbot])
243
 
244
  gr.Examples(
245
  examples=[
246
+ ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river1.jpg"],
247
+ ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river2.jpg"]
 
 
 
 
248
  ],
249
  inputs=image_input,
250
  outputs=analysis_output,
251
+ fn=analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
252
  cache_examples=True,
253
+ label="Try example images:"
254
  )
255
 
256
+ demo.queue(max_size=3).launch()