# app.py - River Pollution Analyzer with instructblip-flan-t5-xl
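# Assumed runtime dependencies (not pinned here): torch, transformers, accelerate,
# bitsandbytes, gradio, pillow. The 4-bit bitsandbytes path is used only when a
# CUDA GPU is available; on CPU the model falls back to full float32.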
import torch
from transformers import (
    InstructBlipProcessor,
    InstructBlipForConditionalGeneration,
    BitsAndBytesConfig
)
import gradio as gr
from PIL import Image
import logging
import functools

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@functools.cache
def get_analyzer():
    logger.info("Loading instructblip-flan-t5-xl...")
    try:
        # 4-bit config (works on GPU if available)
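        # NF4 with double quantization keeps the quantized weights at roughly a quarter of their fp16 size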
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
        
        model = InstructBlipForConditionalGeneration.from_pretrained(
            "Salesforce/instructblip-flan-t5-xl",
            quantization_config=quant_config if torch.cuda.is_available() else None,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
        return processor, model
        
    except Exception as e:
        logger.error(f"Model load failed: {str(e)}")
        raise RuntimeError("Model loading error. Check logs.")

def analyze_image(image):
    if image is None:
        return "⚠️ Please upload an image first."
    try:
        processor, model = get_analyzer()
        prompt = """Analyze river pollution. List pollutants and severity (1-10).
Respond EXACTLY like this:
Pollutants: [list]
Severity: [number]"""
        
        # Preprocess the image and prompt, then move the tensors to the model's device
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,  # sampling must be enabled for temperature to take effect
                temperature=0.7
            )
            
        result = processor.decode(outputs[0], skip_special_tokens=True)
        
        # Format output; fall back to the raw model text if the expected template is missing
        if "Pollutants:" in result and "Severity:" in result:
            pollutants = result.split("Pollutants:")[1].split("Severity:")[0].strip()
            severity = result.split("Severity:")[1].strip()
            return f"""🌊 Analysis Result:
πŸ“Œ Pollutants: {pollutants}
πŸ“ˆ Severity: {severity}/10"""
        return result
        
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        return f"⚠️ Error (try a smaller image): {str(e)}"

# Minimal UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌍 River Pollution Analyzer (instructblip-flan-t5-xl)")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        analyze_btn = gr.Button("Analyze", variant="primary")
    output = gr.Textbox(label="Result")
    
    analyze_btn.click(analyze_image, inputs=image_input, outputs=output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
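
# Local run (a sketch, assuming the dependency set noted at the top of the file):
#   pip install torch transformers accelerate bitsandbytes gradio pillow
#   python app.py
# Gradio then serves the UI at http://0.0.0.0:7860 (the host/port passed to demo.launch above).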