# Hugging Face Spaces status when this source was captured: Runtime error
import torch | |
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, BitsAndBytesConfig | |
import gradio as gr | |
from PIL import Image | |
import re | |
from typing import List, Tuple | |
# bitsandbytes 4-bit quantization settings: NF4 weight format with double
# quantization, and float16 as the compute dtype. Handed to the model's
# from_pretrained call via quantization_config.
quant_config = BitsAndBytesConfig(
    **{
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": torch.float16,
        "load_in_4bit": True,
    }
)
class RiverPollutionAnalyzer:
    """Detect and score river pollution in images with InstructBLIP (Flan-T5-XL).

    The vision-language model is asked to name pollutants from a fixed
    vocabulary (``self.pollutants``) and to rate severity from 1 to 10. Its
    free-text answer is parsed by ``_parse_response`` and rendered as Markdown
    by ``_format_analysis``. ``analyze_chat`` is a small rule-based responder
    that does not use the model.
    """

    # Captures the text after a "Pollutants:"/"Contaminants:" label, up to a
    # closing bracket, a "Severity" label or the end of the line. MULTILINE
    # makes ``$`` match before a newline, so the two-line answer format that
    # the prompt requests ("Pollutants: ...\nSeverity: N") parses correctly.
    _POLLUTANT_RE = re.compile(
        r"(?im)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|severity|$)"
    )
    # First 1-2 digit number following "severity" or "level".
    _SEVERITY_RE = re.compile(r"(?i)(severity|level)[:\s]*(\d{1,2})")
    # Separators between listed pollutants: comma, semicolon or the word "and".
    _SPLIT_RE = re.compile(r"[,;]|\band\b", re.IGNORECASE)
    # Characters trimmed from both ends of each listed item (whitespace,
    # sentence punctuation, quotes, list bullets).
    _STRIP_CHARS = " \t\r\n.'\"*-"
    _GREETINGS = frozenset({"hello", "hi", "hey"})
    # Relative hazard weight per pollutant, used when the model gives no score.
    _SEVERITY_WEIGHTS = {
        "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
        "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
        "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
        "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
    }

    def __init__(self):
        """Set up the pollutant vocabulary and load the 4-bit quantized model.

        Raises:
            RuntimeError: if the processor or model cannot be loaded; the
                original exception is chained as ``__cause__``.
        """
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]
        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }
        try:
            # InstructBLIP-FLAN-T5-XL with the module-level 4-bit quant_config.
            self.processor = InstructBlipProcessor.from_pretrained(
                "Salesforce/instructblip-flan-t5-xl",
                cache_dir="model_cache",
            )
            self.model = InstructBlipForConditionalGeneration.from_pretrained(
                "Salesforce/instructblip-flan-t5-xl",
                quantization_config=quant_config,
                device_map="auto",
                torch_dtype=torch.float16,
                cache_dir="model_cache",
            )
        except Exception as e:
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def analyze_image(self, image) -> str:
        """Run the model on one image and return a Markdown pollution report.

        Args:
            image: a PIL image, or an array accepted by ``Image.fromarray``.
                ``None`` (nothing uploaded) yields a prompt to upload.

        Returns:
            The formatted report, or a "⚠️ Analysis failed: ..." message if
            preprocessing, generation or parsing raised.
        """
        if image is None:
            return "⚠️ Please upload an image to analyze."
        prompt = (
            "Analyze this river pollution scene and provide:\n"
            "1. List ALL visible pollutants ONLY from: "
            f"[{', '.join(self.pollutants)}]\n"
            "2. Estimate pollution severity from 1-10\n"
            "Respond EXACTLY in this format:\n"
            "Pollutants: [comma separated list]\n"
            "Severity: [number]"
        )
        try:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            # Normalise greyscale/RGBA uploads to 3-channel RGB.
            image = image.convert("RGB")
            # Floating-point inputs (pixel values) are cast to the model's
            # float16 dtype; BatchFeature.to leaves integer token ids as-is.
            inputs = self.processor(
                images=image,
                text=prompt,
                return_tensors="pt",
            ).to(self.model.device, torch.float16)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.5,
                    top_p=0.85,
                    do_sample=True,
                )
            analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            pollutants, severity = self._parse_response(analysis)
            return self._format_analysis(pollutants, severity)
        except Exception as e:
            # UI boundary: report the failure in the output panel instead of
            # raising into Gradio.
            return f"⚠️ Analysis failed: {str(e)}"

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Extract ``(pollutants, severity)`` from the model's free-text answer.

        Pollutants come from the text after a "Pollutants:"/"Contaminants:"
        label (up to "]", "Severity" or end of line). That text is split on
        commas, semicolons and the word "and". Each item is trimmed of
        surrounding punctuation and lower-cased, then kept only if it is in
        ``self.pollutants``. Duplicates are removed, preserving order. When no
        such label exists at all, known pollutant names found anywhere in the
        text are used, in order of first appearance.

        Severity is the first 1-2 digit number after "severity" or "level",
        clamped to 1..10. Without one, it is estimated from the pollutants
        via ``_calculate_severity``.
        """
        known = set(self.pollutants)
        label_match = self._POLLUTANT_RE.search(analysis)
        if label_match:
            pieces = self._SPLIT_RE.split(label_match.group(2))
            cleaned = (piece.strip(self._STRIP_CHARS).lower() for piece in pieces)
            # dict.fromkeys de-duplicates while preserving first-seen order.
            pollutants = list(dict.fromkeys(p for p in cleaned if p in known))
        else:
            lowered = analysis.lower()
            hits = sorted(
                (lowered.find(name), name) for name in self.pollutants if name in lowered
            )
            pollutants = [name for _, name in hits]

        severity_match = self._SEVERITY_RE.search(analysis)
        if severity_match:
            # \d{1,2} always parses as an int; clamp to the 1..10 scale.
            severity = min(max(int(severity_match.group(2)), 1), 10)
        else:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate severity (1..10) from the mean hazard weight of ``pollutants``.

        Returns 1 when the list is empty; names without a weight count as 1.
        """
        if not pollutants:
            return 1
        total = sum(self._SEVERITY_WEIGHTS.get(p, 1) for p in pollutants)
        avg_weight = total / len(pollutants)
        return min(10, max(1, round(avg_weight * 3)))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Render the report as Markdown.

        The report contains a title, up to five numbered pollutants (or a
        "no pollutants" note), and the severity as "N/10" with a ten-cell bar
        and its text description. Sections are separated by blank lines so
        Markdown does not merge them into a single paragraph.
        """
        if pollutants:
            pollutants_list = "\n".join(
                f"{i}. {p.capitalize()}" for i, p in enumerate(pollutants[:5], 1)
            )
        else:
            pollutants_list = "🎉 No pollutants detected"
        severity_bar = (
            f"📊 Severity: {severity}/10\n\n"
            f"{'█' * severity}{'░' * (10 - severity)}\n\n"
            f"{self.severity_descriptions.get(severity, '')}"
        )
        return f"🌊 River Pollution Analysis 🌊\n\n{pollutants_list}\n\n{severity_bar}"

    def analyze_chat(self, message: str) -> str:
        """Answer a chat message with a canned response.

        Greetings are detected as whole words ("hi" no longer matches inside
        "which" or "chemical"). Messages mentioning "pollution" get a list of
        common pollutants; anything else gets a help message.
        """
        text = (message or "").lower()
        if self._GREETINGS & set(re.findall(r"[a-z]+", text)):
            return "Hello! I'm a river pollution analyzer. Ask me about pollution types or upload an image for analysis."
        if "pollution" in text:
            return "Common river pollutants include: plastic waste, chemical foam, industrial discharge, sewage water, and oil spills."
        return "I can answer questions about river pollution. Try asking about pollution types or upload an image for analysis."
# Build the analyzer once at import time. If loading fails, keep the UI
# running with analyzer=None and show the error in the page header
# (model_status is rendered there).
try:
    analyzer = RiverPollutionAnalyzer()
    model_status = "✅ Model loaded successfully"
except Exception as e:
    analyzer = None
    model_status = f"❌ Model loading failed: {str(e)}"
    # Also echo the failure to stdout so it appears in the Space's logs.
    print(model_status, flush=True)
css = """ | |
.header { | |
text-align: center; | |
padding: 20px; | |
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); | |
border-radius: 10px; | |
margin-bottom: 20px; | |
} | |
.side-by-side { | |
display: flex; | |
gap: 20px; | |
} | |
.left-panel, .right-panel { | |
flex: 1; | |
} | |
.analysis-box { | |
padding: 20px; | |
background: #f8f9fa; | |
border-radius: 10px; | |
margin-top: 20px; | |
border: 1px solid #dee2e6; | |
} | |
.chat-container { | |
background: #f8f9fa; | |
padding: 20px; | |
border-radius: 10px; | |
height: 100%; | |
} | |
""" | |
def _model_not_loaded(image):
    """Stand-in image handler used when the model failed to load."""
    return "Model not loaded"


def _respond(message, history):
    """Handle one chat turn for both the textbox submit and the Ask button.

    Returns ``("", new_history)``: the empty string clears the textbox, and
    ``new_history`` is the chat with a ``(message, reply)`` turn appended.
    Blank messages leave the history unchanged. A ``None`` history is treated
    as empty, and a missing analyzer yields a "Model not loaded" reply
    instead of an AttributeError.
    """
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    if analyzer is not None:
        reply = analyzer.analyze_chat(message)
    else:
        reply = "Model not loaded"
    return "", history + [(message, reply)]


_analyze_fn = analyzer.analyze_image if analyzer is not None else _model_not_loaded

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌊 River Pollution Analyzer")
        gr.Markdown(f"### {model_status}")
    with gr.Row(elem_classes="side-by-side"):
        with gr.Column(elem_classes="left-panel"):
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### 📊 Analysis report")
                analysis_output = gr.Markdown()
        with gr.Column(elem_classes="right-panel"):
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask about pollution sources...",
                        label="Your Question",
                        container=False,
                        scale=5,
                    )
                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
                clear_btn = gr.Button("🧹 Clear Chat History", size="sm")

    analyze_btn.click(_analyze_fn, inputs=image_input, outputs=analysis_output)
    chat_input.submit(_respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
    chat_btn.click(_respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
    # Reset to an explicitly empty history rather than None.
    clear_btn.click(lambda: [], outputs=[chatbot])

    gr.Examples(
        examples=[
            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river1.jpg"],
            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river2.jpg"],
        ],
        inputs=image_input,
        outputs=analysis_output,
        fn=_analyze_fn,
        # Pre-computing cached outputs only makes sense when a real model is
        # loaded; otherwise it would just download the images to cache a
        # constant error string.
        cache_examples=analyzer is not None,
        label="Try example images:",
    )

demo.queue(max_size=3).launch()