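"""River Pollution Analyzer: a Gradio app.

Pipeline: BLIP captions an uploaded river image, then a 4-bit-quantized
FLAN-T5-XL extracts visible pollutants and a 1-10 severity score from the
caption. A chat panel answers follow-up questions using the same FLAN-T5
model.
"""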
import os
import re
from typing import List, Tuple

import torch
import gradio as gr
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    BlipProcessor,
    BlipForConditionalGeneration,
    BitsAndBytesConfig,
)

# Create the model cache and example-image directories up front
os.makedirs("model_cache", exist_ok=True)
os.makedirs("examples", exist_ok=True)
# Configuration for 4-bit quantization of the language model
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
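# Rough memory math: NF4 stores weights in 4 bits, about a quarter of the
# fp16 footprint, and double quantization additionally compresses the
# per-block quantization constants; matmuls still run in fp16.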
class RiverPollutionAnalyzer:
    def __init__(self):
        try:
            # BLIP handles image captioning; weights are cached locally
            self.blip_processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                cache_dir="model_cache",
            )
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                torch_dtype=torch.float16,
                device_map="auto",
                cache_dir="model_cache",
            )
            # FLAN-T5-XL is loaded in 4-bit (see quant_config above)
            self.tokenizer = AutoTokenizer.from_pretrained(
                "google/flan-t5-xl",
                cache_dir="model_cache",
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-xl",
                device_map="auto",
                quantization_config=quant_config,
                cache_dir="model_cache",
            )
        except Exception as e:
            raise RuntimeError(f"Model loading failed: {e}") from e
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]
        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }
    def analyze_image(self, image):
        """Two-step analysis: BLIP captioning followed by FLAN-T5 reasoning."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        try:
            # Step 1: caption the image with BLIP (cast inputs to fp16 to
            # match the model weights)
            inputs = self.blip_processor(image, return_tensors="pt").to(
                self.blip_model.device, torch.float16
            )
            caption_ids = self.blip_model.generate(**inputs, max_new_tokens=100)[0]
            caption = self.blip_processor.decode(caption_ids, skip_special_tokens=True)
            # Step 2: have FLAN-T5 extract pollutants and severity from the caption
            prompt = f"""Analyze this river scene: '{caption}'
1. List visible pollutants from: {self.pollutants}
2. Estimate severity (1-10)
Respond EXACTLY as:
Pollutants: [comma separated list]
Severity: [number]"""
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=200)
            analysis = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            pollutants, severity = self._parse_response(analysis)
            return self._format_analysis(pollutants, severity)
        except Exception as e:
            return f"⚠️ Analysis failed: {e}"
    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Parse the model response into a pollutants list and severity score."""
        pollutants = []
        severity = 0
        # Extract pollutants, e.g. "Pollutants: [plastic waste, oil spill]"
        pollutants_match = re.search(r"Pollutants:\s*\[(.*?)\]", analysis)
        if pollutants_match:
            pollutants = [p.strip() for p in pollutants_match.group(1).split(",") if p.strip()]
        # Extract severity, e.g. "Severity: 7"
        severity_match = re.search(r"Severity:\s*(\d+)", analysis)
        if severity_match:
            severity = int(severity_match.group(1))
        # If parsing failed or the score is out of range, fall back to a
        # severity derived from the detected pollutants
        if not 1 <= severity <= 10:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity
    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Calculate severity as the average of per-pollutant base scores."""
        if not pollutants:
            return 1
        severity_map = {
            "plastic waste": 4,
            "chemical foam": 7,
            "industrial discharge": 8,
            "sewage water": 6,
            "oil spill": 9,
            "organic debris": 3,
            "construction waste": 5,
            "medical waste": 8,
            "floating trash": 4,
            "algal bloom": 6,
            "toxic sludge": 9,
            "agricultural runoff": 5,
        }
        # Unknown pollutants default to 3; the mean is rounded and clamped to
        # 1-10. Example: ["oil spill", "plastic waste"] -> (9 + 4) / 2 = 6.5,
        # and round(6.5) = 6 (banker's rounding).
        base_score = sum(severity_map.get(p, 3) for p in pollutants)
        avg_score = base_score / len(pollutants)
        return min(10, max(1, round(avg_score)))
    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Format the analysis results into a markdown report."""
        if not pollutants:
            pollutants = ["No visible pollution detected"]
        pollutants_list = "\n".join(f"- {p}" for p in pollutants)
        severity_desc = self.severity_descriptions.get(severity, "Unknown severity level")
        return f"""
## Pollution Analysis Report

### Identified Pollutants:
{pollutants_list}

### Severity Assessment:
**Level {severity}/10** - {severity_desc}

### Recommended Actions:
{self._get_recommendations(severity)}
"""
    def _get_recommendations(self, severity: int) -> str:
        """Get recommendations based on severity level."""
        if severity <= 3:
            return "Monitor the situation. Consider community clean-up efforts."
        elif severity <= 5:
            return "Local authorities should investigate. Basic remediation needed."
        elif severity <= 7:
            return "Immediate containment required. Environmental assessment needed."
        elif severity <= 9:
            return "Emergency response required. Notify environmental agencies."
        else:
            return "Disaster response needed. Evacuation may be necessary."
    def analyze_chat(self, message: str) -> str:
        """Handle free-form chat questions about river pollution."""
        prompt = f"""You are an environmental expert. Answer this question about river pollution: {message}
Provide a concise, factual response in under 100 words."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=150)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
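# Manual smoke test (a sketch; assumes an example image exists at this path):
#   analyzer = RiverPollutionAnalyzer()
#   print(analyzer.analyze_image(Image.open("examples/polluted_river1.jpg")))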
# Initialize the analyzer, surfacing load failures in the UI instead of crashing
try:
    analyzer = RiverPollutionAnalyzer()
    model_status = "✅ Models loaded successfully"
except Exception as e:
    analyzer = None
    model_status = f"❌ Model loading failed: {e}"
# Gradio interface
css = """
.header {
    text-align: center;
    max-width: 800px;
    margin: auto;
}
.header img {
    max-width: 100%;
}
.side-by-side {
    display: flex;
    flex-wrap: wrap;
    gap: 20px;
}
.left-panel, .right-panel {
    flex: 1;
    min-width: 300px;
}
.analysis-box {
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    margin-top: 15px;
    background: #f9f9f9;
}
.chat-container {
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    background: #f9f9f9;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌍 River Pollution Analyzer")
        gr.Markdown(f"### {model_status}")

    with gr.Row(elem_classes="side-by-side"):
        # Left panel: image upload and analysis report
        with gr.Column(elem_classes="left-panel"):
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### 📊 Analysis Report")
                analysis_output = gr.Markdown()

        # Right panel: Q&A chat
        with gr.Column(elem_classes="right-panel"):
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask about pollution sources...",
                        label="Your Question",
                        container=False,
                        scale=5,
                    )
                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
                clear_btn = gr.Button("🧹 Clear Chat History", size="sm")
    # Wire up the analysis button
    analyze_btn.click(
        analyzer.analyze_image if analyzer else (lambda x: "Model not loaded"),
        inputs=image_input,
        outputs=analysis_output,
    )
    def respond(message, chat_history):
        """Append the user question and model answer, then clear the textbox."""
        if not analyzer:
            return "", chat_history + [(message, "Models not loaded. Please try again later.")]
        response = analyzer.analyze_chat(message)
        return "", chat_history + [(message, response)]

    chat_btn.click(respond, [chat_input, chatbot], [chat_input, chatbot])
    chat_input.submit(respond, [chat_input, chatbot], [chat_input, chatbot])
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    # Local example images; the files must exist under examples/ for caching to work
    gr.Examples(
        examples=[
            ["examples/polluted_river1.jpg"],
            ["examples/polluted_river2.jpg"],
        ],
        inputs=image_input,
        outputs=analysis_output,
        fn=analyzer.analyze_image if analyzer else (lambda x: "Model not loaded"),
        cache_examples=True,
        label="Try example images:",
    )
# Launch with a small request queue for stability; allow serving files from examples/
demo.queue(max_size=3).launch(allowed_paths=["examples"])