atharwaah1work commited on
Commit
a0e1483
Β·
verified Β·
1 Parent(s): 1545304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -58
app.py CHANGED
@@ -1,34 +1,53 @@
1
- # app.py - CPU-Compatible Version for Hugging Face Spaces
 
2
  import torch
3
- from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
4
  import gradio as gr
5
  from PIL import Image
6
  import re
7
  import os
8
  from typing import List, Tuple
9
 
10
- # Create model cache directory
11
- os.makedirs("model_cache", exist_ok=True)
 
 
 
 
 
12
 
13
  class RiverPollutionAnalyzer:
14
  def __init__(self):
15
  try:
16
- # Initialize model (CPU-compatible version)
17
  self.processor = InstructBlipProcessor.from_pretrained(
18
  "Salesforce/instructblip-flan-t5-xl",
19
  cache_dir="model_cache"
20
  )
21
- self.model = InstructBlipForConditionalGeneration.from_pretrained(
22
- "Salesforce/instructblip-flan-t5-xl",
23
- device_map="auto",
24
- torch_dtype=torch.float32, # Using float32 instead of quantization
25
- cache_dir="model_cache",
26
- low_cpu_mem_usage=True
27
- )
28
- self.model_loaded = True
29
- self.status = "βœ… Model loaded successfully (CPU mode)"
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  except Exception as e:
31
- self.model_loaded = False
32
  self.status = f"❌ Model loading failed: {str(e)}"
33
  print(self.status)
34
 
@@ -53,47 +72,49 @@ class RiverPollutionAnalyzer:
53
  }
54
 
55
  def analyze_image(self, image):
56
- """Analyze river pollution with CPU optimizations"""
57
- if not self.model_loaded:
58
  return "Model not loaded. Please check logs."
59
 
60
- try:
61
- if not isinstance(image, Image.Image):
62
- image = Image.fromarray(image)
63
-
64
- # Resize for CPU efficiency
65
- image = image.resize((512, 512))
66
-
67
- prompt = """Analyze this river pollution and list:
68
- 1. Visible pollutants from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
69
- 2. Severity estimate (1-10)
70
 
71
- Respond EXACTLY as:
 
 
 
 
72
  Pollutants: [comma separated list]
73
  Severity: [number]"""
74
 
 
75
  inputs = self.processor(
76
  images=image,
77
  text=prompt,
78
  return_tensors="pt"
79
- ) # No .to("cuda") for CPU
80
 
81
- outputs = self.model.generate(
82
- **inputs,
83
- max_new_tokens=100, # Reduced for CPU
84
- temperature=0.5,
85
- top_p=0.85
86
- )
 
 
87
 
88
  analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
89
  pollutants, severity = self._parse_response(analysis)
90
  return self._format_analysis(pollutants, severity)
91
-
92
  except Exception as e:
93
  return f"⚠️ Analysis error: {str(e)}"
94
 
 
95
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
96
- """Parse model response"""
97
  pollutants = []
98
  severity = 3
99
 
@@ -123,7 +144,7 @@ Severity: [number]"""
123
  return pollutants, severity
124
 
125
  def _calculate_severity(self, pollutants: List[str]) -> int:
126
- """Calculate weighted severity"""
127
  if not pollutants:
128
  return 1
129
 
@@ -138,7 +159,7 @@ Severity: [number]"""
138
  return min(10, max(1, round(avg_weight * 3)))
139
 
140
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
141
- """Generate formatted report"""
142
  severity_bar = f"""πŸ“Š Severity: {severity}/10
143
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
144
  {self.severity_descriptions.get(severity, '')}"""
@@ -150,6 +171,15 @@ Severity: [number]"""
150
  {pollutants_list}
151
  {severity_bar}"""
152
 
 
 
 
 
 
 
 
 
 
153
  # Initialize analyzer
154
  analyzer = RiverPollutionAnalyzer()
155
 
@@ -186,9 +216,6 @@ css = """
186
  background: #2a2a2a;
187
  border-color: #444;
188
  }
189
- .btn-primary {
190
- margin-top: 10px;
191
- }
192
  """
193
 
194
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
@@ -200,15 +227,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
200
  # Left Panel
201
  with gr.Column(elem_classes="left-panel"):
202
  with gr.Group():
203
- image_input = gr.Image(
204
- type="pil",
205
- label="Upload River Image",
206
- height=300
207
- )
208
- analyze_btn = gr.Button(
209
- "πŸ” Analyze Pollution",
210
- variant="primary"
211
- )
212
 
213
  with gr.Group(elem_classes="analysis-box"):
214
  gr.Markdown("### πŸ“Š Analysis Report")
@@ -217,16 +237,16 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
217
  # Right Panel
218
  with gr.Column(elem_classes="right-panel"):
219
  with gr.Group(elem_classes="chat-container"):
220
- gr.Markdown("### πŸ’¬ Pollution Q&A")
221
- chatbot = gr.Chatbot(height=400)
222
  with gr.Row():
223
  chat_input = gr.Textbox(
224
- placeholder="Ask about pollution types...",
 
225
  container=False,
226
  scale=5
227
  )
228
- chat_btn = gr.Button("Send", variant="secondary", scale=1)
229
- clear_btn = gr.Button("Clear Chat", size="sm")
230
 
231
  analyze_btn.click(
232
  analyzer.analyze_image,
@@ -234,7 +254,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
234
  outputs=analysis_output
235
  )
236
 
237
- # Example images (host them in your Space's repository)
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  gr.Examples(
239
  examples=[
240
  ["examples/polluted_river1.jpg"],
@@ -243,7 +276,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
243
  inputs=image_input,
244
  outputs=analysis_output,
245
  fn=analyzer.analyze_image,
246
- cache_examples=False, # Disabled for CPU
247
  label="Example Images"
248
  )
249
 
 
1
+ !pip install -q transformers accelerate bitsandbytes gradio torch pillow
2
+
3
  import torch
4
+ from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig
5
  import gradio as gr
6
  from PIL import Image
7
  import re
8
  import os
9
  from typing import List, Tuple
10
 
11
+ # Configuration for 4-bit quantization (if GPU available)
12
+ quant_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_compute_dtype=torch.float16,
15
+ bnb_4bit_quant_type="nf4",
16
+ bnb_4bit_use_double_quant=True
17
+ )
18
 
19
  class RiverPollutionAnalyzer:
20
  def __init__(self):
21
  try:
22
+ # Initialize model with fallback for CPU
23
  self.processor = InstructBlipProcessor.from_pretrained(
24
  "Salesforce/instructblip-flan-t5-xl",
25
  cache_dir="model_cache"
26
  )
27
+
28
+ if torch.cuda.is_available():
29
+ self.model = InstructBlipForConditionalGeneration.from_pretrained(
30
+ "Salesforce/instructblip-flan-t5-xl",
31
+ device_map="auto",
32
+ quantization_config=quant_config,
33
+ torch_dtype=torch.float16,
34
+ cache_dir="model_cache"
35
+ )
36
+ self.device = "cuda"
37
+ self.status = "βœ… Model loaded (4-bit GPU)"
38
+ else:
39
+ self.model = InstructBlipForConditionalGeneration.from_pretrained(
40
+ "Salesforce/instructblip-flan-t5-xl",
41
+ device_map="auto",
42
+ torch_dtype=torch.float32,
43
+ cache_dir="model_cache",
44
+ low_cpu_mem_usage=True
45
+ )
46
+ self.device = "cpu"
47
+ self.status = "⚠️ Model loaded (CPU mode - slower)"
48
+
49
  except Exception as e:
50
+ self.model = None
51
  self.status = f"❌ Model loading failed: {str(e)}"
52
  print(self.status)
53
 
 
72
  }
73
 
74
  def analyze_image(self, image):
75
+ """Analyze river pollution with device-aware processing"""
76
+ if not self.model:
77
  return "Model not loaded. Please check logs."
78
 
79
+ if not isinstance(image, Image.Image):
80
+ image = Image.fromarray(image)
81
+
82
+ # Resize for efficiency
83
+ image = image.resize((512, 512))
 
 
 
 
 
84
 
85
+ prompt = """Analyze this river pollution scene and provide:
86
+ 1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
87
+ 2. Estimate pollution severity from 1-10
88
+
89
+ Respond EXACTLY in this format:
90
  Pollutants: [comma separated list]
91
  Severity: [number]"""
92
 
93
+ try:
94
  inputs = self.processor(
95
  images=image,
96
  text=prompt,
97
  return_tensors="pt"
98
+ ).to(self.model.device)
99
 
100
+ with torch.no_grad():
101
+ outputs = self.model.generate(
102
+ **inputs,
103
+ max_new_tokens=150, # Reduced for stability
104
+ temperature=0.5,
105
+ top_p=0.85,
106
+ do_sample=True
107
+ )
108
 
109
  analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
110
  pollutants, severity = self._parse_response(analysis)
111
  return self._format_analysis(pollutants, severity)
 
112
  except Exception as e:
113
  return f"⚠️ Analysis error: {str(e)}"
114
 
115
+ # [Keep all existing helper methods unchanged]
116
  def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
117
+ """Same parsing logic as before"""
118
  pollutants = []
119
  severity = 3
120
 
 
144
  return pollutants, severity
145
 
146
  def _calculate_severity(self, pollutants: List[str]) -> int:
147
+ """Same severity calculation"""
148
  if not pollutants:
149
  return 1
150
 
 
159
  return min(10, max(1, round(avg_weight * 3)))
160
 
161
  def _format_analysis(self, pollutants: List[str], severity: int) -> str:
162
+ """Same formatting"""
163
  severity_bar = f"""πŸ“Š Severity: {severity}/10
164
  {"β–ˆ" * severity}{"β–‘" * (10 - severity)}
165
  {self.severity_descriptions.get(severity, '')}"""
 
171
  {pollutants_list}
172
  {severity_bar}"""
173
 
174
+ def analyze_chat(self, message: str) -> str:
175
+ """Handle chat questions"""
176
+ if any(word in message.lower() for word in ["hello", "hi", "hey"]):
177
+ return "Hello! I'm a river pollution analyzer. Ask me about pollution types."
178
+ elif "pollution" in message.lower():
179
+ return "Common river pollutants: plastic waste, chemical foam, industrial discharge, sewage water, oil spills."
180
+ else:
181
+ return "I can answer questions about river pollution. Try asking about pollution types."
182
+
183
  # Initialize analyzer
184
  analyzer = RiverPollutionAnalyzer()
185
 
 
216
  background: #2a2a2a;
217
  border-color: #444;
218
  }
 
 
 
219
  """
220
 
221
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
 
227
  # Left Panel
228
  with gr.Column(elem_classes="left-panel"):
229
  with gr.Group():
230
+ image_input = gr.Image(type="pil", label="Upload River Image", height=300)
231
+ analyze_btn = gr.Button("πŸ” Analyze Pollution", variant="primary")
 
 
 
 
 
 
 
232
 
233
  with gr.Group(elem_classes="analysis-box"):
234
  gr.Markdown("### πŸ“Š Analysis Report")
 
237
  # Right Panel
238
  with gr.Column(elem_classes="right-panel"):
239
  with gr.Group(elem_classes="chat-container"):
240
+ chatbot = gr.Chatbot(label="Pollution Q&A", height=400)
 
241
  with gr.Row():
242
  chat_input = gr.Textbox(
243
+ placeholder="Ask about pollution sources...",
244
+ label="Your Question",
245
  container=False,
246
  scale=5
247
  )
248
+ chat_btn = gr.Button("πŸ’¬ Ask", variant="secondary", scale=1)
249
+ clear_btn = gr.Button("🧹 Clear Chat", size="sm")
250
 
251
  analyze_btn.click(
252
  analyzer.analyze_image,
 
254
  outputs=analysis_output
255
  )
256
 
257
+ chat_input.submit(
258
+ lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
259
+ inputs=[chat_input, chatbot],
260
+ outputs=[chat_input, chatbot]
261
+ )
262
+
263
+ chat_btn.click(
264
+ lambda msg, chat: ("", chat + [(msg, analyzer.analyze_chat(msg))]),
265
+ inputs=[chat_input, chatbot],
266
+ outputs=[chat_input, chatbot]
267
+ )
268
+
269
+ clear_btn.click(lambda: None, outputs=[chatbot])
270
+
271
  gr.Examples(
272
  examples=[
273
  ["examples/polluted_river1.jpg"],
 
276
  inputs=image_input,
277
  outputs=analysis_output,
278
  fn=analyzer.analyze_image,
279
+ cache_examples=torch.cuda.is_available(), # Cache only if GPU available
280
  label="Example Images"
281
  )
282