Atharwaaah committed
Commit 7a48fbc · verified · 1 Parent(s): abb872c

Update app.py

Files changed (1)
  1. app.py +78 -136
app.py CHANGED
@@ -1,144 +1,86 @@
- # app.py
  import torch
- from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
  import gradio as gr
  from PIL import Image
- import re
- from typing import List, Tuple
- import os
-
- # Set cache directory (important for Hugging Face Spaces)
- CACHE_DIR = "model_cache"
- os.makedirs(CACHE_DIR, exist_ok=True)
-
-
- class RiverPollutionAnalyzer:
-     def __init__(self):
-         try:
-             # Load with caching and optimized device placement
-             self.processor = InstructBlipProcessor.from_pretrained(
-                 "Salesforce/instructblip-vicuna-7b", cache_dir=CACHE_DIR
-             )
-
-             self.model = InstructBlipForConditionalGeneration.from_pretrained(
-                 "Salesforce/instructblip-vicuna-7b",
-                 cache_dir=CACHE_DIR,
-                 device_map="auto",
-                 torch_dtype=torch.float16,
-                 offload_folder="offload",
-                 low_cpu_mem_usage=True,
-             )
-
-         except Exception as e:
-             raise RuntimeError(f"Model loading failed: {str(e)}")
-
-         self.pollutants = [
-             "plastic waste",
-             "chemical foam",
-             "industrial discharge",
-             "sewage water",
-             "oil spill",
-             "organic debris",
-             "construction waste",
-             "medical waste",
-             "floating trash",
-             "algal bloom",
-             "toxic sludge",
-             "agricultural runoff",
-         ]
-
-         self.severity_descriptions = {
-             1: "Minimal pollution - Slightly noticeable",
-             2: "Minor pollution - Small amounts visible",
-             3: "Moderate pollution - Clearly visible",
-             4: "Significant pollution - Affecting water quality",
-             5: "Heavy pollution - Obvious environmental impact",
-             6: "Severe pollution - Large accumulation",
-             7: "Very severe pollution - Major ecosystem impact",
-             8: "Extreme pollution - Dangerous levels",
-             9: "Critical pollution - Immediate action needed",
-             10: "Disaster level - Ecological catastrophe",
-         }
-
-     def analyze_image(self, image):
-         """Analyze river pollution with robust parsing"""
-         if not isinstance(image, Image.Image):
-             image = Image.fromarray(image)
-
-         prompt = """Analyze this river pollution scene and provide:
- 1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
- 2. Estimate pollution severity from 1-10
-
- Respond EXACTLY in this format:
- Pollutants: [comma separated list]
- Severity: [number]"""
-
-         inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(
-             self.model.device
          )

          with torch.no_grad():
-             outputs = self.model.generate(
                  **inputs,
-                 max_new_tokens=200,
-                 temperature=0.5,
-                 top_p=0.85,
-                 do_sample=True,
              )
-
-         analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
-         pollutants, severity = self._parse_response(analysis)
-         return self._format_analysis(pollutants, severity)
-
-     # ... [keep all your other methods unchanged] ...
-
-
- # Initialize analyzer with error handling
- try:
-     analyzer = RiverPollutionAnalyzer()
-     model_status = "Model loaded successfully!"
- except Exception as e:
-     analyzer = None
-     model_status = f"Model failed to load: {str(e)}"
-     print(model_status)
-
-
- # Create wrapper function for Gradio
- def analyze_image_wrapper(image):
-     if analyzer is None:
-         return (
-             f"⚠️ Error: {model_status}\nPlease try again later or use a smaller image."
-         )
-     return analyzer.analyze_image(image)
-
-
- # Gradio interface
- css = """
- /* [keep your existing CSS] */
- """
-
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
-     # [keep all your existing UI code]
-
-     analyze_btn.click(
-         analyze_image_wrapper, inputs=image_input, outputs=analysis_output
-     )
-
-     # Update examples to use wrapper
-     gr.Examples(
-         examples=[
-             [
-                 "https://huggingface.co/spaces/Atharwaaah/SLCR-FLOWCODE-tarak.AI/resolve/main/polluted_river1.jpg"
-             ],
-             [
-                 "https://huggingface.co/spaces/Atharwaaah/SLCR-FLOWCODE-tarak.AI/resolve/main/polluted_river2.jpg"
-             ],
-         ],
-         inputs=image_input,
-         outputs=analysis_output,
-         fn=analyze_image_wrapper,
-         cache_examples=True,
-         label="Try example images:",
-     )
-
- demo.launch()
 
+ # app.py - River Pollution Analyzer with instructblip-flan-t5-xl
  import torch
+ from transformers import (
+     InstructBlipProcessor,
+     InstructBlipForConditionalGeneration,
+     BitsAndBytesConfig
+ )
  import gradio as gr
  from PIL import Image
+ import logging
+ import functools
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ @functools.cache
+ def get_analyzer():
+     logger.info("Loading instructblip-flan-t5-xl...")
+     try:
+         # 4-bit config (applied only when a GPU is available)
+         quant_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
          )
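+         # NF4 4-bit weights with double quantization cut weight memory to roughly
+         # a quarter of the float16 footprint; bnb_4bit_compute_dtype keeps the
+         # matmuls themselves in float16.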
 
+         processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
+
+         model = InstructBlipForConditionalGeneration.from_pretrained(
+             "Salesforce/instructblip-flan-t5-xl",
+             quantization_config=quant_config if torch.cuda.is_available() else None,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto"
+         )
+         return processor, model
+
+     except Exception as e:
+         logger.error(f"Model load failed: {str(e)}")
+         raise RuntimeError("Model loading error. Check logs.") from e
+
+ def analyze_image(image):
+     try:
+         processor, model = get_analyzer()
+         prompt = """Analyze river pollution. List pollutants and severity (1-10).
+ Respond EXACTLY like this:
+ Pollutants: [list]
+ Severity: [number]"""
+
+         inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
+
          with torch.no_grad():
+             outputs = model.generate(
                  **inputs,
+                 max_new_tokens=100,
+                 do_sample=True,  # sampling must be enabled for temperature to take effect
+                 temperature=0.7,
              )
+
+         result = processor.decode(outputs[0], skip_special_tokens=True)
+
+         # Format output
+         if "Pollutants:" in result and "Severity:" in result:
+             pollutants = result.split("Pollutants:")[1].split("Severity:")[0].strip()
+             severity = result.split("Severity:")[1].strip()
+             return f"""🌊 Analysis Result:
+ 📌 Pollutants: {pollutants}
+ 📈 Severity: {severity}/10"""
+         return result
+
+     except Exception as e:
+         logger.error(f"Error: {str(e)}")
+         return f"⚠️ Error (try a smaller image): {str(e)}"
+
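+ # Worked example of the split-based parsing above (model output is illustrative):
+ #   result = "Pollutants: plastic waste, chemical foam Severity: 6"
+ #   -> pollutants == "plastic waste, chemical foam", severity == "6"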
+ # Minimal UI
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🌍 River Pollution Analyzer (instructblip-flan-t5-xl)")
+     with gr.Row():
+         image_input = gr.Image(type="pil", label="Upload Image")
+         analyze_btn = gr.Button("Analyze", variant="primary")
+     output = gr.Textbox(label="Result")
+
+     analyze_btn.click(analyze_image, inputs=image_input, outputs=output)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)