Atharwaaah committed on
Commit b9c078d · verified · 1 Parent(s): ec516f9

Update app.py

Files changed (1):
  1. app.py +174 -19
app.py CHANGED
@@ -1,24 +1,179 @@
+# Dependencies (add to requirements.txt): transformers accelerate bitsandbytes gradio torch pillow
+
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    BlipProcessor,
+    BlipForConditionalGeneration,
+    BitsAndBytesConfig
+)
 import gradio as gr
 from PIL import Image
-from transformers import BlipProcessor, BlipForConditionalGeneration
-import torch
+import re
+from typing import List, Tuple
 
-# Load a smaller, GPU-compatible model
-processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
-
-def generate_caption(image):
-    inputs = processor(images=image, return_tensors="pt")
-    out = model.generate(**inputs)
-    caption = processor.decode(out[0], skip_special_tokens=True)
-    return caption
-
-demo = gr.Interface(
-    fn=generate_caption,
-    inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Image Caption Generator",
-    description="Upload an image and get a caption using BLIP base model."
+# Configuration for 4-bit quantization
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True
 )
 
-demo.launch()
+class RiverPollutionAnalyzer:
+    def __init__(self):
+        try:
+            # Initialize BLIP for image captioning
+            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+            self.blip_model = BlipForConditionalGeneration.from_pretrained(
+                "Salesforce/blip-image-captioning-base",
+                torch_dtype=torch.float16,
+                device_map="auto"
+            )
+
+            # Initialize FLAN-T5-XL for text analysis
+            self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                "google/flan-t5-xl",
+                device_map="auto",
+                quantization_config=quant_config
+            )
+
+        except Exception as e:
+            raise RuntimeError(f"Model loading failed: {str(e)}")
+
+        self.pollutants = [
+            "plastic waste", "chemical foam", "industrial discharge",
+            "sewage water", "oil spill", "organic debris",
+            "construction waste", "medical waste", "floating trash",
+            "algal bloom", "toxic sludge", "agricultural runoff"
+        ]
+
+        self.severity_descriptions = {
+            1: "Minimal pollution - Slightly noticeable",
+            2: "Minor pollution - Small amounts visible",
+            3: "Moderate pollution - Clearly visible",
+            4: "Significant pollution - Affecting water quality",
+            5: "Heavy pollution - Obvious environmental impact",
+            6: "Severe pollution - Large accumulation",
+            7: "Very severe pollution - Major ecosystem impact",
+            8: "Extreme pollution - Dangerous levels",
+            9: "Critical pollution - Immediate action needed",
+            10: "Disaster level - Ecological catastrophe"
+        }
+
+    def analyze_image(self, image):
+        """Two-step analysis: BLIP captioning + FLAN-T5 analysis"""
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        try:
+            # Step 1: Generate image caption with BLIP
+            inputs = self.blip_processor(image, return_tensors="pt").to(self.blip_model.device, torch.float16)
+            caption = self.blip_model.generate(**inputs, max_new_tokens=100)[0]
+            caption = self.blip_processor.decode(caption, skip_special_tokens=True)
+
+            # Step 2: Analyze caption with FLAN-T5
+            prompt = f"""Analyze this river scene: '{caption}'
+1. List visible pollutants from: {self.pollutants}
+2. Estimate severity (1-10)
+
+Respond EXACTLY as:
+Pollutants: [comma separated list]
+Severity: [number]"""
+
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            outputs = self.model.generate(**inputs, max_new_tokens=200)
+            analysis = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            pollutants, severity = self._parse_response(analysis)
+            return self._format_analysis(pollutants, severity)
+
+        except Exception as e:
+            return f"⚠️ Analysis failed: {str(e)}"
+
+    # [Keep all your existing parsing/formatting methods unchanged]
+    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
+        """Same parsing logic as before"""
+        # ... (unchanged from your original code) ...
+
+    def _calculate_severity(self, pollutants: List[str]) -> int:
+        """Same severity calculation"""
+        # ... (unchanged from your original code) ...
+
+    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
+        """Same formatting"""
+        # ... (unchanged from your original code) ...
+
+    def analyze_chat(self, message: str) -> str:
+        """Same chat handler"""
+        # ... (unchanged from your original code) ...
+
+# Initialize with error handling
+try:
+    analyzer = RiverPollutionAnalyzer()
+    model_status = "✅ Models loaded successfully"
+except Exception as e:
+    analyzer = None
+    model_status = f"❌ Model loading failed: {str(e)}"
+
+# Gradio Interface (unchanged layout from your original)
+css = """
+/* [Keep your existing CSS] */
+"""
+
+with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+    with gr.Column(elem_classes="header"):
+        gr.Markdown("# 🌍 River Pollution Analyzer")
+        gr.Markdown(f"### {model_status}")
+
+    with gr.Row(elem_classes="side-by-side"):
+        # Left Panel
+        with gr.Column(elem_classes="left-panel"):
+            with gr.Group():
+                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
+                analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")
+
+            with gr.Group(elem_classes="analysis-box"):
+                gr.Markdown("### 📊 Analysis Report")
+                analysis_output = gr.Markdown()
+
+        # Right Panel
+        with gr.Column(elem_classes="right-panel"):
+            with gr.Group(elem_classes="chat-container"):
+                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
+                with gr.Row():
+                    chat_input = gr.Textbox(
+                        placeholder="Ask about pollution sources...",
+                        label="Your Question",
+                        container=False,
+                        scale=5
+                    )
+                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
+                clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
+
+    # Connect functions
+    analyze_btn.click(
+        analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
+        inputs=image_input,
+        outputs=analysis_output
+    )
+
+    # [Keep all other UI event handlers unchanged]
+
+    # Update examples to use local files
+    gr.Examples(
+        examples=[
+            ["examples/polluted_river1.jpg"],
+            ["examples/polluted_river2.jpg"]
+        ],
+        inputs=image_input,
+        outputs=analysis_output,
+        fn=analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
+        cache_examples=True,
+        label="Try example images:"
+    )
+
+# Launch with queue for stability
+demo.queue(max_size=3).launch()
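
Note on the elided helpers: the commit keeps _parse_response, _calculate_severity, _format_analysis, and analyze_chat as placeholders ("unchanged from your original code"), so their bodies are not part of this diff. For reference only, here is a minimal _parse_response sketch, assuming the FLAN-T5 output follows the "Pollutants: ... / Severity: ..." format the prompt requests. This is a hypothetical illustration, not the author's committed implementation; it relies on the re, List, and Tuple imports already present in app.py.

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        # Hypothetical sketch, not the committed implementation.
        pollutants: List[str] = []
        severity = 1
        # Pull the "Pollutants:" line and keep only names from the known list.
        m = re.search(r"Pollutants:\s*(.+)", analysis, re.IGNORECASE)
        if m:
            candidates = [p.strip(" .[]") for p in m.group(1).split(",")]
            pollutants = [p for p in candidates if p in self.pollutants]
        # Pull the "Severity:" number and clamp it to the 1-10 scale
        # used by severity_descriptions.
        m = re.search(r"Severity:\s*(\d+)", analysis, re.IGNORECASE)
        if m:
            severity = max(1, min(10, int(m.group(1))))
        return pollutants, severity

Because small instruction-tuned models often drift from a requested output format, a parser along these lines should stay best-effort: if neither regex matches, it falls back to an empty pollutant list and severity 1 rather than raising inside analyze_image.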