import re
from typing import List, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    InstructBlipForConditionalGeneration,
    InstructBlipProcessor,
)

# Configuration for 4-bit quantization (NF4, double quantization, fp16 compute)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)


class RiverPollutionAnalyzer:
    def __init__(self):
        try:
            # Initialize InstructBLIP-FLAN-T5-XL with 4-bit quantization
            self.processor = InstructBlipProcessor.from_pretrained(
                "Salesforce/instructblip-flan-t5-xl",
                cache_dir="model_cache",
            )
            self.model = InstructBlipForConditionalGeneration.from_pretrained(
                "Salesforce/instructblip-flan-t5-xl",
                quantization_config=quant_config,
                device_map="auto",
                torch_dtype=torch.float16,
                cache_dir="model_cache",
            )
        except Exception as e:
            raise RuntimeError(f"Model loading failed: {str(e)}")

        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]

        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }

    def analyze_image(self, image):
        """Analyze a river image and return a formatted pollution report."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        prompt = """Analyze this river pollution scene and provide:
1. List ALL visible pollutants ONLY from: [plastic waste, chemical foam, industrial discharge, sewage water, oil spill, organic debris, construction waste, medical waste, floating trash, algal bloom, toxic sludge, agricultural runoff]
2. Estimate pollution severity from 1-10

Respond EXACTLY in this format:
Pollutants: [comma separated list]
Severity: [number]"""

        try:
            inputs = self.processor(
                images=image,
                text=prompt,
                return_tensors="pt",
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.5,
                    top_p=0.85,
                    do_sample=True,
                )

            analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            pollutants, severity = self._parse_response(analysis)
            return self._format_analysis(pollutants, severity)
        except Exception as e:
            return f"⚠️ Analysis failed: {str(e)}"

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Robustly extract the pollutant list and severity score from the model's reply."""
        pollutants = []
        severity = 3  # default when nothing usable can be parsed

        # Extract pollutants (tolerate missing brackets and varied separators)
        pollutant_match = re.search(
            r'(?i)(pollutants?|contaminants?)[:\s]*\[?(.*?)(?:\]|Severity|severity|$)',
            analysis,
        )
        if pollutant_match:
            pollutants_str = pollutant_match.group(2).strip()
            pollutants = [
                p.strip().lower()
                for p in re.split(r'[,;]|\band\b', pollutants_str)
                if p.strip().lower() in self.pollutants
            ]

        # Extract severity, clamped to 1-10; otherwise fall back to a weighted estimate
        severity_match = re.search(r'(?i)(severity|level)[:\s]*(\d{1,2})', analysis)
        if severity_match:
            try:
                severity = min(max(int(severity_match.group(2)), 1), 10)
            except ValueError:
                severity = self._calculate_severity(pollutants)
        else:
            severity = self._calculate_severity(pollutants)

        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate severity from pollutant weights when the model gives no usable score."""
        if not pollutants:
            return 1
        weights = {
            "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
            "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
            "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
            "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
        }
        avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
        return min(10, max(1, round(avg_weight * 3)))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Build the Markdown report shown in the analysis panel."""
        severity_bar = f"""📊 Severity: {severity}/10
{"█" * severity}{"░" * (10 - severity)}
{self.severity_descriptions.get(severity, '')}"""

        if not pollutants:
            pollutants_list = "\n🔍 No pollutants detected"
        else:
            pollutants_list = "\n".join(
                f"{i}. {p.capitalize()}" for i, p in enumerate(pollutants[:5], 1)
            )

        return f"""🌊 River Pollution Analysis 🌊

{pollutants_list}

{severity_bar}"""

    def analyze_chat(self, message: str) -> str:
        """Simple keyword-based responder for the Q&A chat panel."""
        if any(word in message.lower() for word in ["hello", "hi", "hey"]):
            return ("Hello! I'm a river pollution analyzer. Ask me about pollution types "
                    "or upload an image for analysis.")
        elif "pollution" in message.lower():
            return ("Common river pollutants include: plastic waste, chemical foam, "
                    "industrial discharge, sewage water, and oil spills.")
        else:
            return ("I can answer questions about river pollution. Try asking about "
                    "pollution types or upload an image for analysis.")
# Initialize with error handling so the UI still renders if the model fails to load
try:
    analyzer = RiverPollutionAnalyzer()
    model_status = "✅ Model loaded successfully"
except Exception as e:
    analyzer = None
    model_status = f"❌ Model loading failed: {str(e)}"

css = """
.header {
    text-align: center;
    padding: 20px;
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    border-radius: 10px;
    margin-bottom: 20px;
}
.side-by-side {
    display: flex;
    gap: 20px;
}
.left-panel, .right-panel {
    flex: 1;
}
.analysis-box {
    padding: 20px;
    background: #f8f9fa;
    border-radius: 10px;
    margin-top: 20px;
    border: 1px solid #dee2e6;
}
.chat-container {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 10px;
    height: 100%;
}
"""


def respond(msg, chat):
    """Append the user's message and the analyzer's reply to the chat history."""
    reply = analyzer.analyze_chat(msg) if analyzer else "Model not loaded"
    return "", chat + [(msg, reply)]


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌍 River Pollution Analyzer")
        gr.Markdown(f"### {model_status}")

    with gr.Row(elem_classes="side-by-side"):
        with gr.Column(elem_classes="left-panel"):
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### 📊 Analysis report")
                analysis_output = gr.Markdown()

        with gr.Column(elem_classes="right-panel"):
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask about pollution sources...",
                        label="Your Question",
                        container=False,
                        scale=5,
                    )
                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
                clear_btn = gr.Button("🧹 Clear Chat History", size="sm")

    analyze_btn.click(
        analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
        inputs=image_input,
        outputs=analysis_output,
    )

    chat_input.submit(respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
    chat_btn.click(respond, inputs=[chat_input, chatbot], outputs=[chat_input, chatbot])
    clear_btn.click(lambda: None, outputs=[chatbot])

    gr.Examples(
        examples=[
            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river1.jpg"],
            ["https://huggingface.co/spaces/atharwaah1work/tarak.AI/resolve/main/polluted_river2.jpg"],
        ],
        inputs=image_input,
        outputs=analysis_output,
        fn=analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
        cache_examples=True,
        label="Try example images:",
    )

demo.queue(max_size=3).launch()
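
# Rough local-run sketch (assumptions: the script is saved as app.py, per the usual
# Hugging Face Spaces convention; package versions are not pinned here):
#   pip install torch transformers accelerate bitsandbytes gradio pillow
#   python app.py
# Note: the 4-bit BitsAndBytesConfig path typically requires a CUDA GPU; on a CPU-only
# machine the quantization_config would need to be dropped and the model loaded unquantized.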