atharwaah1work committed on
Commit 3a7adc5 · verified · 1 Parent(s): ff7ee96

Update app.py

Files changed (1)
  1. app.py +50 -70
app.py CHANGED
@@ -1,11 +1,16 @@
+# app.py
 import torch
 from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig
 import gradio as gr
 from PIL import Image
 import re
+import os
 from typing import List, Tuple
 
-# Configuration for 4-bit quantization
+# Create model cache directory
+os.makedirs("model_cache", exist_ok=True)
+
+# 4-bit quantization config
 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.float16,
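The hunk only shows the first two fields of `quant_config`. For context, a minimal sketch of how a 4-bit InstructBLIP load is typically wired up with such a config; the `bnb_4bit_quant_type`, `bnb_4bit_use_double_quant`, and `device_map="auto"` values below are assumptions, not taken from this commit:

```python
import torch
from transformers import (
    BitsAndBytesConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

# Assumed completion of the truncated config: only load_in_4bit and
# bnb_4bit_compute_dtype are visible in the diff.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",        # assumption
    bnb_4bit_use_double_quant=True,   # assumption
)

processor = InstructBlipProcessor.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl", cache_dir="model_cache"
)
model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="auto",                # assumption; requires accelerate
    cache_dir="model_cache",
)
```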
@@ -16,7 +21,7 @@ quant_config = BitsAndBytesConfig(
 class RiverPollutionAnalyzer:
     def __init__(self):
         try:
-            # Initialize InstructBLIP-FLAN-T5-XL with 4-bit quantization
+            # Load InstructBLIP-FLAN-T5-XL with 4-bit quantization
             self.processor = InstructBlipProcessor.from_pretrained(
                 "Salesforce/instructblip-flan-t5-xl",
                 cache_dir="model_cache"
@@ -28,8 +33,12 @@ class RiverPollutionAnalyzer:
                 torch_dtype=torch.float16,
                 cache_dir="model_cache"
             )
+            self.model_loaded = True
+            self.status = "✅ Model loaded successfully"
         except Exception as e:
-            raise RuntimeError(f"Model loading failed: {str(e)}")
+            self.model_loaded = False
+            self.status = f"❌ Model loading failed: {str(e)}"
+            print(self.status)
 
         self.pollutants = [
             "plastic waste", "chemical foam", "industrial discharge",
@@ -52,15 +61,17 @@ class RiverPollutionAnalyzer:
     }
 
     def analyze_image(self, image):
-        """Analyze river pollution with robust parsing"""
+        if not self.model_loaded:
+            return "Model not loaded. Please check logs."
+
         if not isinstance(image, Image.Image):
             image = Image.fromarray(image)
 
-        prompt = """Analyze this river pollution scene and provide:
-1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
-2. Estimate pollution severity from 1-10
+        prompt = """Analyze this river pollution and list:
+1. Visible pollutants from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
+2. Severity estimate (1-10)
 
-Respond EXACTLY in this format:
+Respond EXACTLY as:
 Pollutants: [comma separated list]
 Severity: [number]"""
@@ -71,22 +82,19 @@ Severity: [number]"""
                 return_tensors="pt"
             ).to(self.model.device)
 
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=200,
-                    temperature=0.5,
-                    top_p=0.85,
-                    do_sample=True
-                )
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=150,  # Reduced for memory
+                temperature=0.5,
+                top_p=0.85
+            )
 
             analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
             pollutants, severity = self._parse_response(analysis)
             return self._format_analysis(pollutants, severity)
         except Exception as e:
-            return f"⚠️ Analysis failed: {str(e)}"
+            return f"⚠️ Analysis error: {str(e)}"
 
-    # [Keep your existing parsing/formatting methods]
     def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
         pollutants = []
         severity = 3
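Note that this hunk drops both the `torch.no_grad()` wrapper and the explicit `do_sample=True`; without sampling enabled, `temperature`/`top_p` have no effect and `generate` falls back to greedy decoding. A hedged sketch of how the same call could keep gradient tracking off and keep sampling active (`generate_analysis` is an illustrative helper, not part of the commit):

```python
import torch

# Sketch only (not part of this commit): inference_mode avoids autograd
# bookkeeping, and do_sample=True is required for temperature/top_p to matter.
def generate_analysis(model, processor, inputs):
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.5,
            top_p=0.85,
        )
    return processor.batch_decode(outputs, skip_special_tokens=True)[0]
```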
@@ -96,21 +104,16 @@ Severity: [number]"""
             r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
             analysis
         )
-
         if pollutant_match:
             pollutants_str = pollutant_match.group(2).strip()
             pollutants = [
-                p.strip().lower()
+                p.strip().lower()
                 for p in re.split(r'[,;]|\band\b', pollutants_str)
                 if p.strip().lower() in self.pollutants
             ]
 
         # Extract severity
-        severity_match = re.search(
-            r'(?i)(severity|level)[:\s]*(\d{1,2})',
-            analysis
-        )
-
+        severity_match = re.search(r'(?i)(severity|level)[:\s]*(\d{1,2})', analysis)
         if severity_match:
             try:
                 severity = min(max(int(severity_match.group(2)), 1), 10)
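A quick worked example of what these two regexes (unchanged except for the collapsed `re.search` call) extract from a reply that follows the requested `Pollutants: [...]` / `Severity: [number]` format; the reply string is made up, not actual model output:

```python
import re

reply = "Pollutants: [plastic waste, oil spill]\nSeverity: 7"  # made-up example

pollutant_match = re.search(
    r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
    reply,
)
print(pollutant_match.group(2))      # "plastic waste, oil spill"

severity_match = re.search(r'(?i)(severity|level)[:\s]*(\d{1,2})', reply)
print(int(severity_match.group(2)))  # 7, clamped to 1-10 by the caller
```

If the reply deviates from the format, both searches can fail and the method keeps its defaults: an empty pollutant list and severity 3.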
@@ -124,14 +127,12 @@ Severity: [number]"""
     def _calculate_severity(self, pollutants: List[str]) -> int:
         if not pollutants:
             return 1
-
         weights = {
             "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
             "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
             "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
             "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1
         }
-
         avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
         return min(10, max(1, round(avg_weight * 3)))
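`_calculate_severity` itself only loses two blank lines here; for reference, a self-contained restatement of its arithmetic using the weights shown in this hunk:

```python
weights = {
    "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
    "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
    "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
    "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
}

def calculate_severity(pollutants):
    if not pollutants:
        return 1
    avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
    return min(10, max(1, round(avg_weight * 3)))

print(calculate_severity(["oil spill", "plastic waste"]))  # (2.5 + 1.5) / 2 * 3 = 6
print(calculate_severity(["medical waste"]))               # 3 * 3 = 9
print(calculate_severity([]))                              # 1
```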
 
@@ -147,21 +148,8 @@ Severity: [number]"""
 {pollutants_list}
 {severity_bar}"""
 
-    def analyze_chat(self, message: str) -> str:
-        if any(word in message.lower() for word in ["hello", "hi", "hey"]):
-            return "Hello! I'm a river pollution analyzer. Ask me about pollution types or upload an image for analysis."
-        elif "pollution" in message.lower():
-            return "Common river pollutants include: plastic waste, chemical foam, industrial discharge, sewage water, and oil spills."
-        else:
-            return "I can answer questions about river pollution. Try asking about pollution types or upload an image for analysis."
-
-# Initialize with error handling
-try:
-    analyzer = RiverPollutionAnalyzer()
-    model_status = "✅ Model loaded successfully"
-except Exception as e:
-    analyzer = None
-    model_status = f"❌ Model loading failed: {str(e)}"
+# Initialize analyzer
+analyzer = RiverPollutionAnalyzer()
 
 css = """
 .header {
@@ -191,66 +179,58 @@ css = """
     border-radius: 10px;
     height: 100%;
 }
+.dark .analysis-box, .dark .chat-container {
+    background: #2a2a2a;
+    border-color: #444;
+}
 """
 
 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     with gr.Column(elem_classes="header"):
         gr.Markdown("# 🌍 River Pollution Analyzer")
-        gr.Markdown(f"### {model_status}")
+        gr.Markdown(f"### {analyzer.status}")
 
     with gr.Row(elem_classes="side-by-side"):
+        # Left Panel
         with gr.Column(elem_classes="left-panel"):
             with gr.Group():
                 image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                 analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")
 
             with gr.Group(elem_classes="analysis-box"):
-                gr.Markdown("### 📊 Analysis report")
+                gr.Markdown("### 📊 Analysis Report")
                 analysis_output = gr.Markdown()
 
+        # Right Panel
         with gr.Column(elem_classes="right-panel"):
             with gr.Group(elem_classes="chat-container"):
-                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
+                gr.Markdown("### 💬 Pollution Q&A")
+                chatbot = gr.Chatbot(height=400)
                 with gr.Row():
                     chat_input = gr.Textbox(
-                        placeholder="Ask about pollution sources...",
-                        label="Your Question",
+                        placeholder="Ask about pollution types...",
                         container=False,
                         scale=5
                     )
-                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
-                    clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
+                    chat_btn = gr.Button("Send", variant="secondary", scale=1)
+                    clear_btn = gr.Button("Clear Chat", size="sm")
 
     analyze_btn.click(
-        analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
+        analyzer.analyze_image,
        inputs=image_input,
        outputs=analysis_output
    )
 
-    chat_input.submit(
-        lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
-        inputs=[chat_input, chatbot],
-        outputs=[chat_input, chatbot]
-    )
-
-    chat_btn.click(
-        lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
-        inputs=[chat_input, chatbot],
-        outputs=[chat_input, chatbot]
-    )
-
-    clear_btn.click(lambda: None, outputs=[chatbot])
-
     gr.Examples(
         examples=[
-            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river1.jpg"],
-            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river2.jpg"]
+            ["examples/polluted_river1.jpg"],
+            ["examples/polluted_river2.jpg"]
        ],
        inputs=image_input,
        outputs=analysis_output,
-        fn=analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
-        cache_examples=True,
-        label="Try example images:"
+        fn=analyzer.analyze_image,
+        cache_examples=False,  # Disabled for free tier
+        label="Example Images"
    )
 
-demo.queue(max_size=3).launch()
+demo.queue(max_size=2).launch()
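Note that this commit removes `analyze_chat` and every `.submit`/`.click` handler for the chat widgets, while `chatbot`, `chat_input`, `chat_btn`, and `clear_btn` are still created, so the Send and Clear buttons no longer do anything. If the chat panel is meant to stay interactive, a minimal re-wiring sketch in the spirit of the removed handlers could go inside the `gr.Blocks` context (`reply_fn` is a hypothetical stand-in for the deleted `analyze_chat`):

```python
# Sketch only; not part of this commit.
def reply_fn(message: str) -> str:
    # Hypothetical stand-in for the removed analyze_chat() rule-based replies.
    return "I can answer questions about river pollution types and severity."

def respond(message, history):
    # Tuple-style Chatbot history, as used by the previous handlers.
    return "", history + [(message, reply_fn(message))]

chat_input.submit(respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
chat_btn.click(respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
clear_btn.click(lambda: None, outputs=[chatbot])
```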
 