import os
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
import time
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
from torchvision.ops import nms, box_iou
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont, ImageFilter
from breed_health_info import breed_health_info
from breed_noise_info import breed_noise_info
from dog_database import get_dog_description
from scoring_calculation_system import UserPreferences
from recommendation_html_format import format_recommendation_html, get_breed_recommendations
from history_manager import UserHistoryManager
from search_history import create_history_tab, create_history_component
from styles import get_css_styles
from breed_detection import create_detection_tab
from breed_comparison import create_comparison_tab
from breed_recommendation import create_recommendation_tab
from html_templates import (
    format_description_html,
    format_single_dog_result,
    format_multiple_breeds_result,
    format_error_message,
    format_warning_html,
    format_multi_dog_container,
    format_breed_details_html,
    get_color_scheme,
    get_akc_breeds_link
)
from urllib.parse import quote
from ultralytics import YOLO
import asyncio
import traceback

def get_device():
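    """Return the CUDA device when available, otherwise the CPU, and report which one is used."""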
    if torch.cuda.is_available():
        print('Using CUDA GPU')
        return torch.device('cuda')
    else:
        print('Using CPU')
        return torch.device('cpu')

device = get_device()

history_manager = UserHistoryManager()

# Breed class names, ordered to match the classifier's output indices
dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
              "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
              "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
              "Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
              "Chihuahua", "Dachshund", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
              "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
              "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
              "Greater_Swiss_Mountain_Dog","Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
              "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
              "Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
              "Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
              "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
              "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
              "Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu",
              "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
              "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
              "Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
              "Affenpinscher", "Basenji", "Basset", "Beagle", "Black-and-Tan_Coonhound", "Bloodhound",
              "Bluetick", "Borzoi", "Boxer", "Briard", "Bull_Mastiff", "Cairn", "Chow", "Clumber",
              "Cocker_Spaniel", "Collie", "Curly-Coated_Retriever", "Dhole", "Dingo",
              "Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever", "Groenendael", "Keeshond",
              "Kelpie", "Komondor", "Kuvasz", "Malamute", "Malinois", "Miniature_Pinscher",
              "Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
              "Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
              "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
              "Wire-Haired_Fox_Terrier"]


class MultiHeadAttention(nn.Module):
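    """Self-attention over the pooled backbone features.

    The input vector is projected to ``num_heads * head_dim``, split into heads,
    scored with scaled dot products, and the attended result is projected back
    to ``in_dim``.
    """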

    def __init__(self, in_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads
        self.fc_in = nn.Linear(in_dim, self.scaled_dim)
        self.query = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.key = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.value = nn.Linear(self.scaled_dim, self.scaled_dim)
        self.fc_out = nn.Linear(self.scaled_dim, in_dim)

    def forward(self, x):
        N = x.shape[0]
        x = self.fc_in(x)
        q = self.query(x).view(N, self.num_heads, self.head_dim)
        k = self.key(x).view(N, self.num_heads, self.head_dim)
        v = self.value(x).view(N, self.num_heads, self.head_dim)

        energy = torch.einsum("nqd,nkd->nqk", [q, k])
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)

        out = torch.einsum("nqk,nvd->nqd", [attention, v])
        out = out.reshape(N, self.scaled_dim)
        out = self.fc_out(out)
        return out

class BaseModel(nn.Module):
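    """EfficientNetV2-M backbone with an attention-refined classification head.

    The pretrained backbone's classifier is replaced by ``nn.Identity`` so it yields
    pooled features, which pass through ``MultiHeadAttention`` and a
    LayerNorm/Dropout/Linear head. ``forward`` returns ``(logits, attended_features)``.
    """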
    def __init__(self, num_classes, device=None):
        super().__init__()
        if device is None:
            device = get_device()
        self.device = device
        print(f"Initializing model on device: {device}")

        self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
        self.feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes)
        )

        self.to(device)

    def forward(self, x):
        x = x.to(self.device)
        features = self.backbone(x)
        attended_features = self.attention(features)
        logits = self.classifier(attended_features)
        return logits, attended_features

# Initialize model
num_classes = len(dog_breeds)

# Initialize base model
model = BaseModel(num_classes=num_classes, device=device)

# Load model path
model_path = '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/(124_TEST)_models/[124_82.30]_best_model_dog.pth'
checkpoint = torch.load(model_path, map_location=device)

# Load model state
model.load_state_dict(checkpoint['base_model'], strict=False)
model.eval()

# Image preprocessing function
def preprocess_image(image):
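    """Resize to 224x224, convert to a tensor, normalize with ImageNet statistics, and add a batch dimension."""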
    # If the image is a numpy.ndarray, convert it to a PIL.Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Use torchvision.transforms to process images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return transform(image).unsqueeze(0)


# YOLOv8 (large) localizes dogs in the input image; each crop is then passed to the breed classifier
model_yolo = YOLO('yolov8l.pt')
if torch.cuda.is_available():
    model_yolo.to(device)

async def predict_single_dog(image):
    """
    Predicts the dog breed using only the classifier.
    Args:
        image: PIL Image or numpy array
    Returns:
        tuple: (top1_prob, topk_breeds, relative_probs)
    """
    image_tensor = preprocess_image(image).to(device)

    with torch.no_grad():
        # Get model outputs (only the logits are needed; the attended features are unused here)
        logits = model(image_tensor)[0]  # model returns (logits, features); take the first element
        probs = F.softmax(logits, dim=1)

        # Classifier prediction
        top5_prob, top5_idx = torch.topk(probs, k=5)
        breeds = [dog_breeds[idx.item()] for idx in top5_idx[0]]
        probabilities = [prob.item() for prob in top5_prob[0]]

        # Calculate relative probabilities
        sum_probs = sum(probabilities[:3])  # use only the top three to compute relative probabilities
        relative_probs = [f"{(prob/sum_probs * 100):.2f}%" for prob in probabilities[:3]]

        # Debug output
        print("\nClassifier Predictions:")
        for breed, prob in zip(breeds[:5], probabilities[:5]):
            print(f"{breed}: {prob:.4f}")

        return probabilities[0], breeds[:3], relative_probs


async def detect_multiple_dogs(image, conf_threshold=0.3, iou_threshold=0.55):
    results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
    dogs = []
    boxes = []
    for box in results.boxes:
        if box.cls == 16:  # COCO dataset class for dog is 16
            xyxy = box.xyxy[0].tolist()
            confidence = box.conf.item()
            boxes.append((xyxy, confidence))

    if not boxes:
        dogs.append((image, 1.0, [0, 0, image.width, image.height]))
    else:
        nms_boxes = non_max_suppression(boxes, iou_threshold)

        for box, confidence in nms_boxes:
            x1, y1, x2, y2 = box
            w, h = x2 - x1, y2 - y1
            x1 = max(0, x1 - w * 0.05)
            y1 = max(0, y1 - h * 0.05)
            x2 = min(image.width, x2 + w * 0.05)
            y2 = min(image.height, y2 + h * 0.05)
            cropped_image = image.crop((x1, y1, x2, y2))
            dogs.append((cropped_image, confidence, [x1, y1, x2, y2]))

    return dogs

def non_max_suppression(boxes, iou_threshold):
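    """Greedy non-maximum suppression on (box, confidence) pairs.

    Repeatedly keeps the highest-confidence box and discards remaining boxes
    whose IoU with it is at or above ``iou_threshold``.
    """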
    keep = []
    boxes = sorted(boxes, key=lambda x: x[1], reverse=True)
    while boxes:
        current = boxes.pop(0)
        keep.append(current)
        boxes = [box for box in boxes if calculate_iou(current[0], box[0]) < iou_threshold]
    return keep


def calculate_iou(box1, box2):
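    """Compute intersection-over-union for two [x1, y1, x2, y2] boxes."""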
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    iou = intersection / float(area1 + area2 - intersection)
    return iou



def create_breed_comparison(breed1: str, breed2: str) -> dict:
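    """Map two breeds' categorical ratings to numeric values (keeping the raw data) for the comparison tab."""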
    breed1_info = get_dog_description(breed1)
    breed2_info = get_dog_description(breed2)

    # Map the categorical ratings to a numeric scale
    value_mapping = {
        'Size': {'Small': 1, 'Medium': 2, 'Large': 3, 'Giant': 4},
        'Exercise_Needs': {'Low': 1, 'Moderate': 2, 'High': 3, 'Very High': 4},
        'Care_Level': {'Low': 1, 'Moderate': 2, 'High': 3},
        'Grooming_Needs': {'Low': 1, 'Moderate': 2, 'High': 3}
    }

    comparison_data = {
        breed1: {},
        breed2: {}
    }

    for breed, info in [(breed1, breed1_info), (breed2, breed2_info)]:
        comparison_data[breed] = {
            'Size': value_mapping['Size'].get(info['Size'], 2),  # defaults to Medium
            'Exercise_Needs': value_mapping['Exercise_Needs'].get(info['Exercise Needs'], 2),  # defaults to Moderate
            'Care_Level': value_mapping['Care_Level'].get(info['Care Level'], 2),
            'Grooming_Needs': value_mapping['Grooming_Needs'].get(info['Grooming Needs'], 2),
            'Good_with_Children': info['Good with Children'] == 'Yes',
            'Original_Data': info
        }

    return comparison_data


async def predict(image):
    """
    Main prediction function that handles both single and multiple dog detection.

    Args:
        image: PIL Image or numpy array

    Returns:
        tuple: (html_output, annotated_image, initial_state)
    """
    if image is None:
        return format_warning_html("Please upload an image to start."), None, None

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Detect dogs in the image
        dogs = await detect_multiple_dogs(image)
        color_scheme = get_color_scheme(len(dogs) == 1)

        # Prepare for annotation
        annotated_image = image.copy()
        draw = ImageDraw.Draw(annotated_image)

        try:
            font = ImageFont.truetype("arial.ttf", 24)
        except Exception:
            # Fall back to PIL's default font when Arial is unavailable
            font = ImageFont.load_default()

        dogs_info = ""

        # Process each detected dog
        for i, (cropped_image, detection_confidence, box) in enumerate(dogs):
            color = color_scheme if len(dogs) == 1 else color_scheme[i % len(color_scheme)]

            # Draw box and label on image
            draw.rectangle(box, outline=color, width=4)
            label = f"Dog {i+1}"
            label_bbox = draw.textbbox((0, 0), label, font=font)
            label_width = label_bbox[2] - label_bbox[0]
            label_height = label_bbox[3] - label_bbox[1]

            # Draw label background and text
            label_x = box[0] + 5
            label_y = box[1] + 5
            draw.rectangle(
                [label_x - 2, label_y - 2, label_x + label_width + 4, label_y + label_height + 4],
                fill='white',
                outline=color,
                width=2
            )
            draw.text((label_x, label_y), label, fill=color, font=font)

            # Predict breed
            top1_prob, topk_breeds, relative_probs = await predict_single_dog(cropped_image)
            combined_confidence = detection_confidence * top1_prob

            # Format results based on confidence with error handling
            try:
                if combined_confidence < 0.2:
                    dogs_info += format_error_message(color, i+1)
                elif top1_prob >= 0.45:
                    breed = topk_breeds[0]
                    description = get_dog_description(breed)
                    # Handle missing breed description
                    if description is None:
                        # If no description exists, build a basic fallback
                        description = {
                            "Name": breed,
                            "Size": "Unknown",
                            "Exercise Needs": "Unknown",
                            "Grooming Needs": "Unknown",
                            "Care Level": "Unknown",
                            "Good with Children": "Unknown",
                            "Description": f"Identified as {breed.replace('_', ' ')}"
                        }
                    dogs_info += format_single_dog_result(breed, description, color)
                else:
                    # Call format_multiple_breeds_result with a fallback description for unknown breeds
                    dogs_info += format_multiple_breeds_result(
                        topk_breeds,
                        relative_probs,
                        color,
                        i+1,
                        lambda breed: get_dog_description(breed) or {
                            "Name": breed,
                            "Size": "Unknown",
                            "Exercise Needs": "Unknown",
                            "Grooming Needs": "Unknown",
                            "Care Level": "Unknown",
                            "Good with Children": "Unknown",
                            "Description": f"Identified as {breed.replace('_', ' ')}"
                        }
                    )
            except Exception as e:
                print(f"Error formatting results for dog {i+1}: {str(e)}")
                dogs_info += format_error_message(color, i+1)

        # Wrap final HTML output
        html_output = format_multi_dog_container(dogs_info)

        # Prepare initial state
        initial_state = {
            "dogs_info": dogs_info,
            "image": annotated_image,
            "is_multi_dog": len(dogs) > 1,
            "html_output": html_output
        }

        return html_output, annotated_image, initial_state

    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return format_warning_html(error_msg), None, None


def show_details_html(choice, previous_output, initial_state):
    """
    Generate detailed HTML view for a selected breed.

    Args:
        choice: str, Selected breed option
        previous_output: str, Previous HTML output
        initial_state: dict, Current state information

    Returns:
        tuple: (html_output, gradio_update, updated_state)
    """
    if not choice:
        return previous_output, gr.update(visible=True), initial_state

    try:
        breed = choice.split("More about ")[-1]
        description = get_dog_description(breed)
        html_output = format_breed_details_html(description, breed)

        # Update state
        initial_state["current_description"] = html_output
        initial_state["original_buttons"] = initial_state.get("buttons", [])

        return html_output, gr.update(visible=True), initial_state

    except Exception as e:
        error_msg = f"An error occurred while showing details: {e}"
        print(error_msg)
        return format_warning_html(error_msg), gr.update(visible=True), initial_state

def main():
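    """Build the Gradio Blocks interface: header, detection, comparison, recommendation, and history tabs, plus the footer."""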
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with gr.Blocks(css=get_css_styles()) as iface:
        # Header HTML

        gr.HTML("""
        <header style='text-align: center; padding: 20px; margin-bottom: 20px;'>
            <h1 style='font-size: 2.5em; margin-bottom: 10px; color: #2D3748;'>
                🐾 PawMatch AI
            </h1>
            <h2 style='font-size: 1.2em; font-weight: normal; color: #4A5568; margin-top: 5px;'>
                Your Smart Dog Breed Guide
            </h2>
            <div style='width: 50px; height: 3px; background: linear-gradient(90deg, #4299e1, #48bb78); margin: 15px auto;'></div>
            <p style='color: #718096; font-size: 0.9em;'>
                Powered by AI • Breed Recognition • Smart Matching • Companion Guide
            </p>
        </header>
        """)

        # Create the history component instance first (its tab is created later)
        history_component = create_history_component()

        with gr.Tabs():
            # 1. Breed detection tab
            example_images = [
                '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/test_images/Border_Collie.jpg',
                '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/test_images/Golden_Retriever.jpeg',
                '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/test_images/Saint_Bernard.jpeg',
                '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/test_images/Samoyed.jpg',
                '/content/drive/Othercomputers/我的 MacBook Pro/Learning/Cats_Dogs_Detector/test_images/French_Bulldog.jpeg'
            ]
            detection_components = create_detection_tab(predict, example_images)

            # 2. Breed comparison tab
            comparison_components = create_comparison_tab(
                dog_breeds=dog_breeds,
                get_dog_description=get_dog_description,
                breed_health_info=breed_health_info,
                breed_noise_info=breed_noise_info
            )

            # 3. Breed recommendation tab
            recommendation_components = create_recommendation_tab(
                UserPreferences=UserPreferences,
                get_breed_recommendations=get_breed_recommendations,
                format_recommendation_html=format_recommendation_html,
                history_component=history_component
            )


            # 4. Finally, the history tab
            create_history_tab(history_component)

        # Footer
        gr.HTML('''
            <div style="
                display: flex;
                align-items: center;
                justify-content: center;
                gap: 20px;
                padding: 20px 0;
            ">
                <p style="
                    font-family: 'Arial', sans-serif;
                    font-size: 14px;
                    font-weight: 500;
                    letter-spacing: 2px;
                    background: linear-gradient(90deg, #555, #007ACC);
                    -webkit-background-clip: text;
                    -webkit-text-fill-color: transparent;
                    margin: 0;
                    text-transform: uppercase;
                    display: inline-block;
                ">EXPLORE THE CODE →</p>
                <a href="https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/PawMatchAI" style="text-decoration: none;">
                    <img src="https://img.shields.io/badge/GitHub-PawMatch_AI-007ACC?logo=github&style=for-the-badge">
                </a>
            </div>
        ''')

    return iface

if __name__ == "__main__":
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")
    iface = main()
    iface.launch(share=True, debug=True)