import collections
import heapq
import json
import os
import logging
import faiss
import requests
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from PIL import Image
import io
from pathlib import Path
from huggingface_hub import hf_hub_download

# Root-logger configuration shared by the whole script.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

# Read from the environment; not referenced elsewhere in the visible part of this file.
hf_token = os.getenv("HF_TOKEN")

# Identifiers handed to open_clip's create_model / get_tokenizer in the __main__ section.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# NOTE: this download runs at import time. The embeddings come from the Hub,
# while the matching names file is read from the working directory.
txt_emb_npy = hf_hub_download(repo_id="pyesonekyaw/biome_lfs", filename='txt_emb_species.npy', repo_type="dataset")
txt_names_json = "txt_emb_species.json"

# NOTE(review): `min_prob` duplicates MIN_PROB below and is not used in the
# visible code; `k` is the number of predictions returned and shown in the UI.
min_prob = 1e-9
k = 5

# Taxonomic ranks, coarsest to finest; the UI passes an index into this tuple.
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image preprocessing: tensor conversion, resize to 224x224, channel normalization.
# The mean/std values are presumably the CLIP normalization statistics — TODO confirm.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# Probabilities at or below this are ignored when aggregating higher-rank scores.
MIN_PROB = 1e-9
# NOTE(review): not referenced in the visible code (`k` is used instead).
TOP_K_PREDICTIONS = 5
# Number of zero-shot species candidates considered before vote augmentation.
TOP_K_CANDIDATES = 250
# Maximum number of retrieved neighbours kept after filtering.
TOP_N_SIMILAR = 22
# Score added to a prediction per retrieval vote.
SIMILARITY_BOOST = 0.2
# Minimum votes for a retrieved species to be injected as its own prediction.
VOTE_THRESHOLD = 3
# Neighbours scoring above this are skipped as near-exact matches of the query.
SIMILARITY_THRESHOLD = 0.99  

# Local JSON lookups keyed by image id: photo URLs and species names used for retrieval.
PHOTO_LOOKUP_PATH = f"./photo_lookup.json"
SPECIES_LOOKUP_PATH = f"./species_lookup.json"

# Gradio theme used by the Blocks app.
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal, 
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)

# Example images offered in the UI: every *.jpg under ./examples, sorted by path.
EXAMPLES_DIR = Path("examples")
example_images = sorted(str(p) for p in EXAMPLES_DIR.glob("*.jpg"))

def indexed(lst, indices):
    """Return the elements of ``lst`` at the given ``indices``, in that order."""
    picked = []
    for position in indices:
        picked.append(lst[position])
    return picked

def format_name(taxon, common):
    """Join the taxon parts with spaces and append ``(common)`` when a common name is given."""
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific

def combine_duplicate_predictions(predictions):
    """Merge predictions whose names contain one another, then renormalize.

    Names are compared case-insensitively. Candidates are visited longest
    name first (ties broken by higher score); each one absorbs the score of
    every not-yet-merged entry whose name is a substring of it or contains
    it, and the merged score is stored under the visited name.

    Args:
        predictions: Mapping of prediction name to score.

    Returns:
        A new dict of merged names to scores normalized to sum to 1. When
        the scores sum to zero the merged scores are returned unnormalized
        instead of raising ZeroDivisionError.
    """
    # Lowercase every name once instead of once per pairwise comparison.
    lowered = {name: name.lower() for name in predictions}

    combined = {}
    used = set()

    # Longer names first, then higher score first.
    ordered = sorted(predictions.items(), key=lambda item: (-len(item[0]), -item[1]))

    for name, score in ordered:
        if name in used:
            continue
        used.add(name)

        merged_score = score
        name_lower = lowered[name]

        # Absorb every remaining entry that overlaps with this name.
        for other, other_score in predictions.items():
            if other in used:
                continue
            other_lower = lowered[other]
            if name_lower in other_lower or other_lower in name_lower:
                merged_score += other_score
                used.add(other)

        combined[name] = merged_score

    total = sum(combined.values())
    if total == 0:
        # Nothing to normalize against (empty input or all-zero scores).
        return combined
    return {name: score / total for name, score in combined.items()}

@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predict taxonomic labels for an image from zero-shot scores plus retrieval votes.

    The image is embedded with the module-level ``model``, scored against the
    text embeddings ``txt_emb`` and softmaxed. Votes gathered from the nearest
    images in the FAISS index are then added on top of those scores.

    Args:
        img: Input image accepted by ``preprocess_img``.
        rank: Index into ``ranks``; the last index selects species-level output.
        return_all: Unused; kept for interface compatibility.

    Returns:
        ``(predictions, similar_images)`` where ``predictions`` maps at most
        ``k`` names to float scores normalized to sum to 1 (names containing
        one another are merged), and ``similar_images`` is the metadata list
        produced by ``get_similar_images_metadata``.
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Get zero-shot predictions
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    # Get similar images votes and metadata
    species_votes, similar_images = get_similar_images_metadata(img_features, faiss_index, id_mapping, name_mapping)

    if rank + 1 == len(ranks):
        # Species level prediction. Clamp the candidate count so topk does
        # not fail when fewer than TOP_K_CANDIDATES classes exist.
        topk = probs.topk(min(TOP_K_CANDIDATES, probs.numel()))
        predictions = {
            format_name(*txt_names[i]): prob
            for i, prob in zip(topk.indices.tolist(), topk.values.tolist())
        }

        # Augment predictions with votes.
        # NOTE(review): the `elif` is evaluated once per (prediction, species)
        # pair, so a species with enough votes is injected as its own entry
        # whenever at least one candidate does not match it — even if another
        # candidate does. Left as-is because changing it alters the scores;
        # confirm whether "inject only when no candidate matches" was intended.
        augmented_predictions = predictions.copy()
        for pred_name in predictions:
            pred_name_lower = pred_name.lower()
            for voted_species, vote_count in species_votes.items():
                if voted_species in pred_name_lower or pred_name_lower in voted_species:
                    augmented_predictions[pred_name] += SIMILARITY_BOOST * vote_count
                elif vote_count >= VOTE_THRESHOLD:
                    augmented_predictions[voted_species] = vote_count * SIMILARITY_BOOST

        # Keep the k highest-scoring entries
        ranked = sorted(
            augmented_predictions.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        sorted_predictions = dict(ranked[:k])

        # Normalize and combine duplicates
        total = sum(sorted_predictions.values())
        sorted_predictions = {name: score / total for name, score in sorted_predictions.items()}
        sorted_predictions = combine_duplicate_predictions(sorted_predictions)

        logger.info(f"Top K predictions after combining duplicates: {sorted_predictions}")
        return sorted_predictions, similar_images

    # Higher rank prediction: sum species probabilities per truncated taxonomy.
    # flatten() (rather than squeeze()) keeps a 1-D result even when exactly
    # one class passes the threshold; tolist() yields plain ints/floats so the
    # returned scores are Python floats, like in the species branch.
    output = collections.defaultdict(float)
    prob_values = probs.tolist()
    for i in torch.nonzero(probs > MIN_PROB).flatten().tolist():
        output[" ".join(txt_names[i][0][: rank + 1])] += prob_values[i]

    # Incorporate votes for higher ranks
    for species, vote_count in species_votes.items():
        try:
            # Credit the first taxonomy whose joined, lowercased name contains the species
            for taxonomy, _ in txt_names:
                if species in " ".join(taxonomy).lower():
                    higher_rank = " ".join(taxonomy[: rank + 1])
                    output[higher_rank] += SIMILARITY_BOOST * vote_count
                    break
        except Exception as e:
            logger.error(f"Error processing vote for species {species}: {e}")

    # Get top-k predictions and normalize
    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}

    # Normalize probabilities to sum to 1
    total = sum(prediction_dict.values())
    prediction_dict = {name: score / total for name, score in prediction_dict.items()}
    prediction_dict = combine_duplicate_predictions(prediction_dict)

    logger.info(f"Prediction dictionary after combining duplicates: {prediction_dict}")

    return prediction_dict, similar_images


def change_output(choice):
    """Return a cleared ``gr.Label`` whose label is the rank name at index ``choice``."""
    rank_label = ranks[choice]
    return gr.Label(
        value=None,
        label=rank_label,
        show_label=True,
        num_top_classes=k,
    )

def get_cache_paths(name="demo"):
    """Download (or reuse from the local Hub cache) the FAISS index and its ID mapping.

    Args:
        name: Cache name embedded in the remote filenames
            (``cache/faiss_cache_<name>.index`` and
            ``cache/faiss_cache_<name>_mapping.json``). The default ``"demo"``
            resolves to the files that were previously hard-coded.

    Returns:
        Dict with local file paths under the keys ``'index'`` and ``'mapping'``.
    """
    repo_id = "pyesonekyaw/biome_lfs"
    return {
        'index': hf_hub_download(repo_id=repo_id, filename=f'cache/faiss_cache_{name}.index', repo_type="dataset"),
        'mapping': hf_hub_download(repo_id=repo_id, filename=f'cache/faiss_cache_{name}_mapping.json', repo_type="dataset"),
    }

def build_name_mapping(txt_names):
    """Index ``(scientific, common)`` name pairs by both of their lowercased names.

    Entries without a common name, or whose taxonomy has fewer than two
    parts, are skipped. The scientific name is built from the last two
    taxonomy parts.
    """
    lookup = {}
    for lineage, vernacular in txt_names:
        if vernacular and len(lineage) >= 2:
            binomial = f"{lineage[-2]} {lineage[-1]}".lower()
            pair = (binomial, vernacular.lower())
            lookup[binomial] = pair
            lookup[pair[1]] = pair
    return lookup

def load_faiss_index():
    """Read the cached FAISS index and its JSON ID mapping.

    Returns:
        ``(index, id_mapping)``: the object returned by ``faiss.read_index``
        and the parsed contents of the mapping JSON file.
    """
    paths = get_cache_paths()
    index_path, mapping_path = paths['index'], paths['mapping']
    logger.info("Loading FAISS index from cache...")
    loaded_index = faiss.read_index(index_path)
    with open(mapping_path, 'r') as mapping_file:
        loaded_mapping = json.load(mapping_file)
    return loaded_index, loaded_mapping
    
def get_similar_images_metadata(img_embedding, faiss_index, id_mapping, name_mapping):
    """Retrieve nearest neighbours from the FAISS index and turn them into votes.

    Args:
        img_embedding: Torch tensor holding the query embedding (1-D or a
            single-row 2-D tensor).
        faiss_index: Index object exposing ``search(queries, k)``.
        id_mapping: Maps a FAISS result position to an image id.
        name_mapping: Lookup returned by ``build_name_mapping``; used to
            canonicalize species names to their scientific form.

    Returns:
        ``(species_votes, similar_images)``. ``species_votes`` maps a species
        name to the number of the top five retained neighbours carrying it.
        ``similar_images`` is a list of dicts with keys ``'id'``,
        ``'species'``, ``'common_name'`` and ``'similarity'``.
    """
    # FAISS expects a contiguous float32 matrix; detach() keeps this safe
    # outside of a no_grad context as well.
    img_embedding_np = np.ascontiguousarray(
        img_embedding.detach().float().cpu().numpy(), dtype=np.float32
    )
    if img_embedding_np.ndim == 1:
        img_embedding_np = img_embedding_np.reshape(1, -1)

    # Search for more images than needed to account for filtered matches
    distances, indices = faiss_index.search(img_embedding_np, TOP_N_SIMILAR * 2)

    valid_indices = []
    valid_distances = []

    for similarity, idx in zip(distances[0], indices[0]):
        # FAISS pads the result with -1 when fewer than k neighbours exist;
        # indexing id_mapping with -1 would silently pick the last entry.
        if idx < 0:
            continue
        # Skip near-exact matches. The returned score is treated directly as
        # a similarity — assumes an inner-product index, TODO confirm.
        if similarity > SIMILARITY_THRESHOLD:
            continue

        valid_indices.append(idx)
        valid_distances.append(similarity)

        if len(valid_indices) >= TOP_N_SIMILAR:
            break

    species_votes = {}
    similar_images = []

    # Only the top five retained neighbours are used, for votes and display alike.
    for idx, similarity in zip(valid_indices[:5], valid_distances[:5]):
        similar_img_id = id_mapping[idx]

        try:
            raw_names = id_to_species_info.get(similar_img_id)
            species_names = [name for name in raw_names if name] if raw_names else []
            if not species_names:
                # Previously surfaced as a TypeError/IndexError caught below.
                logger.error(f"Error processing JSON for image {similar_img_id}: no species names found")
                continue

            # Canonicalize through name_mapping and de-duplicate while keeping
            # first-seen order, so the displayed species is deterministic.
            processed_names = {}
            for species in species_names:
                name_tuple = name_mapping.get(species)
                processed_names[name_tuple[0] if name_tuple else species] = None
            processed_names = list(processed_names)

            for species in processed_names:
                species_votes[species] = species_votes.get(species, 0) + 1

            similar_images.append({
                'id': similar_img_id,
                'species': processed_names[0] if processed_names else 'Unknown',
                'common_name': species_names[-1],
                'similarity': similarity  # Add similarity score
            })

        except Exception as e:
            # Best effort: a malformed record must not abort the whole request.
            logger.error(f"Error processing JSON for image {similar_img_id}: {e}")
            continue

    return species_votes, similar_images


if __name__ == "__main__":
    # Load the model, lookup tables, text embeddings and FAISS index that the
    # functions above read as module-level globals.
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    # Context managers close the lookup files promptly instead of leaking the
    # handles; JSON is read as UTF-8 regardless of the platform's locale.
    with open(PHOTO_LOOKUP_PATH, encoding="utf-8") as fd:
        id_to_photo_url = json.load(fd)
    with open(SPECIES_LOOKUP_PATH, encoding="utf-8") as fd:
        id_to_species_info = json.load(fd)
    logger.info(f"Loaded {len(id_to_photo_url)} photo mappings")
    logger.info(f"Loaded {len(id_to_species_info)} species mappings")
    # Load text embeddings and the names they correspond to
    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    with open(txt_names_json, encoding="utf-8") as fd:
        txt_names = json.load(fd)

    # Build name mapping
    name_mapping = build_name_mapping(txt_names)

    # Load the cached FAISS index and its ID mapping
    faiss_index, id_mapping = load_faiss_index()

    # Define process_output function before using it
    def _fetch_similar_image(img_id):
        """Download the photo for ``img_id``; return a PIL image or None on any failure."""
        img_url = id_to_photo_url.get(img_id)
        if not img_url:
            # A missing lookup entry must not crash the whole request.
            logger.info(f"No photo URL found for image {img_id}")
            return None

        img_url = img_url.replace("square", "small")
        logger.info(f"Processing image URL: {img_url}")

        try:
            # The timeout keeps a stalled host from hanging the request forever.
            response = requests.get(img_url, timeout=10)
        except requests.RequestException as e:
            logger.error(f"Error processing image {img_id}: {e}")
            return None

        if response.status_code != 200:
            logger.info(f"Failed to fetch image from URL: {response}")
            return None

        try:
            return Image.open(io.BytesIO(response.content))
        except Exception as e:
            logger.info(f"Failed to load image from URL: {e}")
            return None

    def _format_similar_label(img_info):
        """Build the Markdown caption for one similar image; return "" if metadata is malformed."""
        try:
            label = f"**{img_info['species']}**"
            if img_info['common_name']:
                label += f" ({img_info['common_name']})"
            label += f"\nSimilarity: {img_info['similarity']:.3f}"
            label += f"\n[View on iNaturalist](https://www.inaturalist.org/observations/{img_info['id']})"
            return label
        except (KeyError, TypeError, ValueError) as e:
            logger.error(f"Error processing image {img_info.get('id') if isinstance(img_info, dict) else img_info}: {e}")
            return ""

    def process_output(img, rank):
        """Classify ``img`` and collect the similar-image gallery for the UI.

        Args:
            img: Uploaded image (PIL), or None when nothing was uploaded.
            rank: Index into ``ranks`` selected in the dropdown.

        Returns:
            A flat list ``[predictions, image_1..image_5, label_1..label_5]``
            matching the click handler's outputs. Unfilled slots hold ``None``
            for images and ``""`` for labels.
        """
        num_slots = 5  # the UI defines exactly five image/label components

        if img is None:
            logger.info("No image provided; returning empty outputs.")
            return [None] + [None] * num_slots + [""] * num_slots

        predictions, similar_imgs = open_domain_classification(img, rank)

        logger.info(f"Number of similar images found: {len(similar_imgs)}")

        images = []
        labels = []

        # Exactly one image entry and one label entry per neighbour keeps the
        # two lists aligned; the label is added regardless of fetch success.
        for img_info in similar_imgs[:num_slots]:
            images.append(_fetch_similar_image(img_info['id']))
            labels.append(_format_similar_label(img_info))

        # Pad arrays if needed
        images += [None] * (num_slots - len(images))
        labels += [""] * (num_slots - len(labels))

        logger.info(f"Final number of images: {len(images)}")
        logger.info(f"Final number of labels: {len(labels)}")

        return [predictions] + images + labels

    # Build the Gradio UI. Components are created in layout order inside the
    # nested Row/Column contexts, so statement order here is significant.
    with gr.Blocks(theme=theme) as app:
        # Add header
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                gr.Image("image.jpg", elem_id="logo-img", 
                        show_label=False )
            with gr.Column(scale=30):
                gr.Markdown("""Biome is a vision foundation model-powered tool customized to identify Singapore's local biodiversity. 
                <br/> <br/> 
                **Developed by**: Pye Sone Kyaw - AI Engineer @ Multimodal AI Team - AI Practice - GovTech SG
                <br/> <br/> 
                Under the hood, Biome is using [BioCLIP](https://github.com/Imageomics/BioCLIP) augmented with multimodal search and retrieval to enhance its Singapore-specific biodiversity classification capabilities.
                <br/> <br/> 
                Biome work best when the organism is clearly visible and takes up a substantial part of the image.
                """)

        # Query image upload; delivered to the handler as a PIL image.
        with gr.Row(variant="panel", elem_id="images_panel"):
            img_input = gr.Image(
                height=400, 
                sources=["upload"],
                type="pil"
            )
            
        

        with gr.Row():
            
            # Left column: example images, rank selector and submit button.
            with gr.Column():
                with gr.Row():
                    gr.Examples(
                        examples=example_images,
                        inputs=img_input,
                            label="Example Images"
                        )
                # type="index" makes the dropdown pass the position within
                # `ranks`, which change_output uses as `ranks[choice]`.
                rank_dropdown = gr.Dropdown(
                    label="Taxonomic Rank",
                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                    choices=ranks,
                    value="Species",
                    type="index",
                )
                open_domain_btn = gr.Button("Submit", variant="primary")
            # Right column: prediction label showing up to k classes.
            with gr.Column():
                open_domain_output = gr.Label(
                    num_top_classes=k,
                    label="Prediction",
                    show_label=True,
                    value=None,
                )
        
        # New section for similar images
        with gr.Row(variant="panel"):
            with gr.Column():
                gr.Markdown("### Most Similar Images from Database")
            
        # Five image slots and five caption slots, filled by process_output.
        with gr.Row():
            similar_images = [
                gr.Image(label="Similar Image 1", height=200, show_label=True),
                gr.Image(label="Similar Image 2", height=200, show_label=True),
                gr.Image(label="Similar Image 3", height=200, show_label=True),
                gr.Image(label="Similar Image 4", height=200, show_label=True),
                gr.Image(label="Similar Image 5", height=200, show_label=True),
            ]
        
        with gr.Row():
            similar_labels = [
                gr.Markdown("Species 1"),
                gr.Markdown("Species 2"),
                gr.Markdown("Species 3"),
                gr.Markdown("Species 4"),
                gr.Markdown("Species 5"),
            ]
            
        # Changing the rank replaces the prediction label with a cleared one.
        rank_dropdown.change(
            fn=change_output, 
            inputs=rank_dropdown, 
            outputs=[open_domain_output]
        )

        # process_output returns [predictions, 5 images, 5 labels], matching
        # the order of this outputs list.
        open_domain_btn.click(
            fn=process_output,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output] + similar_images + similar_labels,
        )

        with gr.Row(variant="panel"):
            gr.Markdown("""
            **Disclaimer**: This is a proof-of-concept demo for non-commercial purposes. No data is stored or used for any form of training, and all data used for retrieval are from [iNaturalist](https://inaturalist.org/).
            The adage of garbage in, garbage out applies here - uploading images not biodiversity-related will yield unpredictable results.
            """)
    # Queue size is capped at 20 via max_size before launching.
    app.queue(max_size=20)
    app.launch(share=False, enable_monitoring=False, allowed_paths=["/app/"])