import streamlit as st
import torch
from PIL import Image
import numpy as np
from transformers import ViTFeatureExtractor, ViTForImageClassification
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
import logging
import faiss
from typing import List, Dict
from datetime import datetime
from groq import Groq
import os
from functools import lru_cache

# Configure the root logger at INFO level and obtain this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class RAGSystem:
    """Retrieve passages from a small, hard-coded construction-defect knowledge base.

    The embedding model, the knowledge base and the FAISS index are all built
    lazily on first access, so constructing an instance is cheap.
    """

    # Maximum number of (query, k) -> context entries remembered per instance.
    _CONTEXT_CACHE_SIZE = 32

    def __init__(self):
        # Load models only when needed
        self._embedding_model = None
        self._vector_store = None
        self._knowledge_base = None
        # Per-instance LRU cache for get_relevant_context. A functools.lru_cache on
        # the method would be shared class-wide and keep every instance alive.
        self._context_cache = {}

    @property
    def embedding_model(self):
        """SentenceTransformer used to embed documents and queries (loaded on first use)."""
        if self._embedding_model is None:
            self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        return self._embedding_model

    @property
    def knowledge_base(self):
        """List of ``{"text": str, "metadata": {"category": str}}`` documents."""
        if self._knowledge_base is None:
            self._knowledge_base = self.load_knowledge_base()
        return self._knowledge_base

    @property
    def vector_store(self):
        """FAISS index over the knowledge-base embeddings (built on first use)."""
        if self._vector_store is None:
            self._vector_store = self.create_vector_store()
        return self._vector_store

    @staticmethod
    @lru_cache(maxsize=1)  # Cache the knowledge base
    def load_knowledge_base() -> List[Dict]:
        """Load and preprocess knowledge base.

        Flattens every defect entry into a text document of the form
        ``"Category: <name>\\n<key>: <value>\\n..."`` and tags it with its
        category. Because of ``lru_cache`` every call returns the SAME list
        object, so callers must not mutate it.
        """
        kb = {
            "spalling": [
                {
                    "severity": "Critical",
                    "description": "Severe concrete spalling with exposed reinforcement",
                    "repair_method": "Remove deteriorated concrete, clean reinforcement",
                    "immediate_action": "Evacuate area, install support",
                    "prevention": "Regular inspections, waterproofing"
                }
            ],
            "structural_cracks": [
                {
                    "severity": "High",
                    "description": "Active structural cracks >5mm width",
                    "repair_method": "Structural analysis, epoxy injection",
                    "immediate_action": "Install crack monitors",
                    "prevention": "Regular monitoring, load management"
                }
            ],
            "surface_deterioration": [
                {
                    "severity": "Medium",
                    "description": "Surface scaling and deterioration",
                    "repair_method": "Surface preparation, patch repair",
                    "immediate_action": "Document extent, plan repairs",
                    "prevention": "Surface sealers, proper drainage"
                }
            ],
            "corrosion": [
                {
                    "severity": "High",
                    "description": "Corrosion of reinforcement leading to cracks",
                    "repair_method": "Remove rust, apply inhibitors",
                    "immediate_action": "Isolate affected area",
                    "prevention": "Anti-corrosion coatings, proper drainage"
                }
            ],
            "efflorescence": [
                {
                    "severity": "Low",
                    "description": "White powder deposits on concrete surfaces",
                    "repair_method": "Surface cleaning, sealant application",
                    "immediate_action": "Identify moisture source",
                    "prevention": "Improve waterproofing, reduce moisture ingress"
                }
            ],
            "delamination": [
                {
                    "severity": "Medium",
                    "description": "Separation of layers in concrete",
                    "repair_method": "Resurface or replace delaminated sections",
                    "immediate_action": "Inspect bonding layers",
                    "prevention": "Proper curing and bonding agents"
                }
            ],
            "honeycombing": [
                {
                    "severity": "Medium",
                    "description": "Voids in concrete caused by improper compaction",
                    "repair_method": "Grout injection, patch repair",
                    "immediate_action": "Assess structural impact",
                    "prevention": "Proper vibration during pouring"
                }
            ],
            "water_leakage": [
                {
                    "severity": "High",
                    "description": "Water ingress through cracks or joints",
                    "repair_method": "Injection grouting, waterproofing membranes",
                    "immediate_action": "Stop water flow, apply sealants",
                    "prevention": "Drainage systems, joint sealing"
                }
            ],
            "settlement_cracks": [
                {
                    "severity": "High",
                    "description": "Cracks due to uneven foundation settlement",
                    "repair_method": "Foundation underpinning, grouting",
                    "immediate_action": "Monitor movement, stabilize foundation",
                    "prevention": "Soil compaction, proper foundation design"
                }
            ],
            "shrinkage_cracks": [
                {
                    "severity": "Low",
                    "description": "Minor cracks caused by shrinkage during curing",
                    "repair_method": "Sealant application, surface repairs",
                    "immediate_action": "Monitor cracks",
                    "prevention": "Proper curing and moisture control"
                }
            ]
        }

        documents = []
        for category, items in kb.items():
            for item in items:
                doc_text = f"Category: {category}\n"
                for key, value in item.items():
                    doc_text += f"{key}: {value}\n"
                documents.append({"text": doc_text, "metadata": {"category": category}})

        return documents

    def create_vector_store(self):
        """Create FAISS vector store.

        Embeds every knowledge-base text and adds it to an exact L2 index, so
        FAISS row ``i`` corresponds to ``self.knowledge_base[i]``.
        """
        texts = [doc["text"] for doc in self.knowledge_base]
        embeddings = self.embedding_model.encode(texts)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(np.array(embeddings).astype('float32'))
        return index

    @staticmethod
    def _join_hits(documents, indices) -> str:
        """Join the texts of ``documents`` at ``indices``, skipping out-of-range ids.

        FAISS pads its result with ``-1`` when fewer than ``k`` vectors exist;
        indexing a Python list with -1 would silently return the last document.
        """
        return "\n\n".join(
            documents[int(i)]["text"]
            for i in indices
            if 0 <= int(i) < len(documents)
        )

    def get_relevant_context(self, query: str, k: int = 2) -> str:
        """Retrieve the ``k`` knowledge-base passages nearest to ``query``.

        Args:
            query: Free-text user question.
            k: Number of passages to return; clamped to the index size.

        Returns:
            The matching passages joined by blank lines, or ``""`` when nothing
            can be retrieved (including on errors, which are logged).
        """
        cache_key = (query, k)
        cached = self._context_cache.pop(cache_key, None)
        if cached is not None:
            # Re-insert so this entry becomes the most recently used one.
            self._context_cache[cache_key] = cached
            return cached

        try:
            top_k = min(k, self.vector_store.ntotal)
            if top_k <= 0:
                return ""
            query_embedding = self.embedding_model.encode([query])
            _, indices = self.vector_store.search(
                np.ascontiguousarray(query_embedding, dtype='float32'), top_k
            )
            context = self._join_hits(self.knowledge_base, indices[0])
        except Exception as e:
            logger.error(f"Error retrieving context: {e}")
            # Deliberately not cached: a transient failure should be retried.
            return ""

        self._context_cache[cache_key] = context
        if len(self._context_cache) > self._CONTEXT_CACHE_SIZE:
            # dicts keep insertion order, so the first key is the least recently used.
            self._context_cache.pop(next(iter(self._context_cache)))
        return context

class ImageAnalyzer:
    """Score construction images against a small set of defect classes.

    The model and its feature extractor are loaded lazily on first access.

    NOTE(review): ``_load_model`` gives the classifier a freshly initialised
    (untrained) head and no fine-tuned weights are loaded in this class, so
    the returned scores are not meaningful until the model is trained on
    defect data.
    """

    def __init__(self, model_name="microsoft/swin-base-patch4-window7-224-in22k"):
        self.device = "cpu"
        self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
        self.model_name = model_name
        self._model = None
        self._feature_extractor = None

    @property
    def model(self):
        """Classification model, or None if loading failed.

        A failed load leaves ``_model`` as None, so the load is retried on the
        next access.
        """
        if self._model is None:
            self._model = self._load_model()
        return self._model

    @property
    def feature_extractor(self):
        """Feature extractor matching ``model_name``, or None if loading failed."""
        if self._feature_extractor is None:
            self._feature_extractor = self._load_feature_extractor()
        return self._feature_extractor

    def _load_feature_extractor(self):
        """Load the appropriate feature extractor based on model type"""
        try:
            if "swin" in self.model_name:
                from transformers import AutoFeatureExtractor
                return AutoFeatureExtractor.from_pretrained(self.model_name)
            elif "convnext" in self.model_name:
                from transformers import ConvNextFeatureExtractor
                return ConvNextFeatureExtractor.from_pretrained(self.model_name)
            else:
                from transformers import ViTFeatureExtractor
                return ViTFeatureExtractor.from_pretrained(self.model_name)
        except Exception as e:
            logger.error(f"Feature extractor initialization error: {e}")
            return None

    def _load_model(self):
        """Load the backbone chosen by ``model_name`` with a len(defect_classes)-way head.

        Returns None (after logging) if loading fails.
        """
        try:
            if "swin" in self.model_name:
                from transformers import SwinForImageClassification
                model = SwinForImageClassification.from_pretrained(
                    self.model_name,
                    num_labels=len(self.defect_classes),
                    ignore_mismatched_sizes=True
                )
            elif "convnext" in self.model_name:
                from transformers import ConvNextForImageClassification
                model = ConvNextForImageClassification.from_pretrained(
                    self.model_name,
                    num_labels=len(self.defect_classes),
                    ignore_mismatched_sizes=True
                )
            else:
                from transformers import ViTForImageClassification
                model = ViTForImageClassification.from_pretrained(
                    self.model_name,
                    num_labels=len(self.defect_classes),
                    ignore_mismatched_sizes=True
                )

            model = model.to(self.device)

            # Reinitialize the classifier layer with a fresh, untrained Linear.
            # NOTE(review): from_pretrained(num_labels=..., ignore_mismatched_sizes=True)
            # should already produce a head of this size, so this looks redundant —
            # confirm before removing.
            with torch.no_grad():
                if hasattr(model, 'classifier'):
                    in_features = model.classifier.in_features
                    model.classifier = torch.nn.Linear(in_features, len(self.defect_classes))
                elif hasattr(model, 'head'):
                    in_features = model.head.in_features
                    model.head = torch.nn.Linear(in_features, len(self.defect_classes))

            return model
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return None

    def preprocess_image(self, image_bytes):
        """Preprocess image for model input (delegates to the cached module-level helper)."""
        return _cached_preprocess_image(image_bytes, self.model_name)

    def analyze_image(self, image, confidence_threshold: float = 0.3):
        """Score ``image`` against ``defect_classes``.

        Args:
            image: An image accepted by the feature extractor (a PIL image here).
            confidence_threshold: Classes whose softmax probability exceeds this
                value are reported. If none do, the single most likely class is
                returned instead.

        Returns:
            A dict mapping class name to probability, or None if analysis failed
            (the error is logged).
        """
        try:
            if self.model is None:
                raise ValueError("Model not properly initialized")
            if self.feature_extractor is None:
                raise ValueError("Feature extractor not properly initialized")

            inputs = self.feature_extractor(
                images=image,
                return_tensors="pt"
            )
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Probabilities for the first (only) image in the batch.
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].tolist()
            scored = list(zip(self.defect_classes, probs))

            results = {label: p for label, p in scored if p > confidence_threshold}
            if not results:
                # Nothing cleared the threshold: fall back to the top-1 class.
                label, p = max(scored, key=lambda item: item[1])
                results = {label: p}

            return results

        except Exception as e:
            logger.error(f"Analysis error: {str(e)}")
            return None

@st.cache_data
def _cached_preprocess_image(image_bytes, model_name):
    """Open an image, force RGB and resize it to the model's input size.

    Wrapped in ``st.cache_data`` so repeated calls with the same inputs reuse
    the processed image.

    Args:
        image_bytes: A path or binary file-like object accepted by
            ``PIL.Image.open`` (in this app, the Streamlit upload object).
        model_name: Model identifier; names containing "convnext" are resized
            to 384x384, everything else to 224x224.

    Returns:
        The resized RGB ``PIL.Image.Image``, or None if the image could not be
        read (the error is logged).
    """
    try:
        # Image.open reads from the stream's current position; rewind in case
        # the file-like object has already been (partially) consumed.
        if hasattr(image_bytes, "seek"):
            image_bytes.seek(0)

        image = Image.open(image_bytes)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Adjust size based on model requirements
        size = (384, 384) if "convnext" in model_name else (224, 224)
        return image.resize(size, Image.Resampling.LANCZOS)
    except Exception as e:
        logger.exception(f"Image preprocessing error: {e}")
        return None

@st.cache_data
def _fetch_groq_completion(query: str, context: str) -> str:
    """Ask the Groq LLM to answer ``query`` from ``context`` (cached, raises on failure).

    Any exception propagates to the caller. Streamlit only stores a value when
    the function returns normally, so transient API failures are not cached.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))

    prompt = f"""Based on the following context about construction defects, answer the question.
        Context: {context}
        Question: {query}
        Provide a detailed answer based on the given context."""

    response = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a construction defect analysis expert."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        model="llama-3.3-70b-versatile",
        temperature=0.7,
    )
    # Coerce a missing/None message body to "" so callers can use str methods.
    return response.choices[0].message.content or ""


def get_groq_response(query: str, context: str) -> str:
    """Get response from Groq LLM with caching.

    Successful answers are cached by ``_fetch_groq_completion``. Error messages
    are produced here, outside the cache, so a missing API key or a transient
    API failure is not replayed from the cache on later calls.

    Returns:
        The model's answer, or a string starting with ``"Error"`` on failure.
    """
    if not os.getenv("GROQ_API_KEY"):
        return "Error: Groq API key not configured"
    try:
        return _fetch_groq_completion(query, context)
    except Exception as e:
        logger.error(f"Groq API error: {e}", exc_info=True)
        return f"Error: Unable to get response from AI model. Exception: {str(e)}"

def _render_image_analysis():
    """Render the upload -> preprocess -> classify -> chart workflow."""
    st.subheader("Image Analysis")
    uploaded_file = st.file_uploader(
        "Upload a construction image for analysis",
        type=["jpg", "jpeg", "png"],
        key="image_uploader"  # Add key for proper state management
    )
    if uploaded_file is None:
        return

    try:
        # Create a placeholder for the image
        image_placeholder = st.empty()

        # Process image with progress indicator
        with st.spinner('Processing image...'):
            processed_image = st.session_state.analyzer.preprocess_image(uploaded_file)
            if processed_image is None:
                st.error("Failed to process image. Please try another one.")
                return

            image_placeholder.image(processed_image, caption='Uploaded Image', use_container_width=True)

            # Analyze image with progress bar
            progress_bar = st.progress(0)
            with st.spinner('Analyzing defects...'):
                results = st.session_state.analyzer.analyze_image(processed_image)
                progress_bar.progress(100)

            if not results:
                st.warning("No defects detected or analysis failed. Please try another image.")
                return

            st.success('Analysis complete!')

            # Display results
            st.subheader("Detected Defects")
            fig, ax = plt.subplots(figsize=(8, 4))
            ax.barh(list(results.keys()), list(results.values()))
            ax.set_xlim(0, 1)
            fig.tight_layout()
            st.pyplot(fig)
            # pyplot keeps every figure it creates open until closed, and this
            # code runs again on every Streamlit rerun; release the figure.
            plt.close(fig)

            most_likely_defect = max(results.items(), key=lambda item: item[1])[0]
            st.info(f"Most likely defect: {most_likely_defect}")

    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        logger.exception(f"Process error: {e}")


def _render_query_panel():
    """Render the question box, retrieve RAG context and show the LLM answer."""
    st.subheader("Ask About Defects")
    user_query = st.text_input(
        "Ask a question about the defects or repairs:",
        help="Example: What are the repair methods for spalling?"
    )
    if not user_query:
        return

    with st.spinner('Getting answer...'):
        # Get context from RAG system
        context = st.session_state.rag_system.get_relevant_context(user_query)
        if not context:
            st.error("Could not find relevant information. Please try rephrasing your question.")
            return

        # Get response from Groq; failures come back as strings starting with "Error".
        response = get_groq_response(user_query, context)
        if not response.startswith("Error"):
            st.write("Answer:")
            st.markdown(response)
        else:
            st.error(response)

        with st.expander("View retrieved information"):
            st.text(context)


def _render_sidebar():
    """Render the About text, the Groq API-key status and the cache-clearing button."""
    with st.sidebar:
        st.header("About")
        st.write("""
        This tool helps analyze construction defects in images and provides 
        information about repair methods and best practices.
        
        Features:
        - Image analysis for defect detection
        - Information lookup for repair methods
        - Expert AI responses to your questions
        """)

        # Display API status (only checks that the key is set, not that it is valid)
        if os.getenv("GROQ_API_KEY"):
            st.success("Groq API: Connected")
        else:
            st.error("Groq API: Not configured")

        # Add settings section
        st.subheader("Settings")
        if st.button("Clear Cache"):
            st.cache_data.clear()
            st.success("Cache cleared!")


def main():
    """Streamlit entry point: configure the page, create per-session services, render the UI."""
    st.set_page_config(
        page_title="Smart Construction Defect Analyzer",
        page_icon="🏗️",
        layout="wide"
    )

    st.title("🏗️ Smart Construction Defect Analyzer")

    # Initialize systems in session state if not present (kept across reruns)
    if 'analyzer' not in st.session_state:
        st.session_state.analyzer = ImageAnalyzer()
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = RAGSystem()

    col1, col2 = st.columns([1, 1])
    with col1:
        _render_image_analysis()
    with col2:
        _render_query_panel()

    _render_sidebar()


if __name__ == "__main__":
    main()