from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Define your model architecture
class EfficientNetMultiTask(nn.Module):
    """EfficientNet-B0 backbone feeding two independent classification heads.

    The backbone's stock classifier is swapped for ``nn.Identity`` so that it
    emits its feature vector; the area head and the room head both consume
    that same vector.

    Args:
        n_area_classes: number of outputs produced by the area head.
        n_room_classes: number of outputs produced by the room head.
    """

    def __init__(self, n_area_classes, n_room_classes):
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=False)
        # Input width of the Linear layer inside the stock classifier; the
        # replacement heads must accept vectors of this size.
        feature_dim = backbone.classifier[1].in_features
        self.efficientnet = backbone
        self.area_classifier = self._make_head(feature_dim, n_area_classes)
        self.room_classifier = self._make_head(feature_dim, n_room_classes)
        # Remove the stock classifier so the backbone returns features only.
        backbone.classifier = nn.Identity()

    @staticmethod
    def _make_head(feature_dim, n_classes):
        """Build one MLP head: Linear -> ReLU -> Dropout(0.3) -> Linear."""
        return nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, n_classes),
        )

    def forward(self, x):
        """Return ``(area_logits, room_logits)`` computed from the batch ``x``."""
        shared = self.efficientnet(x)
        return self.area_classifier(shared), self.room_classifier(shared)


# Load model
# Output sizes of the two heads; they must match the checkpoint loaded below.
n_area_classes = 21  # Adjust according to your area bins
n_room_classes = 16  # Adjust based on your dataset
model = EfficientNetMultiTask(n_area_classes=n_area_classes, n_room_classes=n_room_classes)

# Load weights. The checkpoint is expected to live next to app.py, so resolve
# the name against this script's directory; a bare relative path would be
# resolved against the current working directory and break when the app is
# launched from elsewhere. If the file is not found next to the script, fall
# back to the path exactly as written (relative to the working directory).
_weights_filename = 'floorplan_model_classification.pth'  # Adjust with your model weights path
_script_relative_weights = Path(__file__).resolve().parent / _weights_filename
model_weights_path = (
    str(_script_relative_weights) if _script_relative_weights.is_file() else _weights_filename
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. On torch >= 1.13, weights_only=True restricts loading to tensors.
model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
model.eval()  # inference mode for Dropout layers

# Preprocessing applied to every uploaded image before inference.
# Per-channel normalization constants (the standard ImageNet mean/std values).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed square input size
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

# Area bin edges in m²: 0, 25, ..., 500, followed by +inf so that the last
# bin is the open-ended "500+" category.
area_bins = list(range(0, 525, 25)) + [float('inf')]

def get_area_from_bin(area_bin_idx):
    """Turn an area-bin index into a human-readable range label in m².

    Indices below ``len(area_bins) - 2`` produce ``"<lower> - <upper> m²"``
    from two consecutive edges of ``area_bins``; every other index produces
    the open-ended label ``"<last finite edge>+ m²"``.
    """
    open_bin_start = len(area_bins) - 2
    if area_bin_idx < open_bin_start:
        lower = area_bins[area_bin_idx]
        upper = area_bins[area_bin_idx + 1]
        return f"{lower} - {upper} m²"
    return f"{area_bins[-2]}+ m²"

# Prediction function
def predict(image):
    """Predict the floor-area range and room count for one floor-plan image.

    Args:
        image: the uploaded image as a NumPy array (the Gradio input below is
            configured with ``type="numpy"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``(predicted_area, predicted_rooms)`` pair of strings: the area-bin
        label from ``get_area_from_bin`` and the predicted room count.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of an opaque AttributeError from PIL.
    """
    if image is None:
        raise gr.Error("Please upload a floor plan image.")

    # Force 3 channels so grayscale or RGBA uploads match the RGB normalization.
    pil_image = Image.fromarray(image).convert('RGB')
    batch = test_transform(pil_image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        area_logits, room_logits = model(batch)

    # softmax is monotonic, so the arg-max of the raw logits selects the same
    # class as the arg-max of the probabilities; normalizing is unnecessary.
    area_pred_idx = torch.argmax(area_logits, dim=1).item()
    room_pred_idx = torch.argmax(room_logits, dim=1).item()

    predicted_area = get_area_from_bin(area_pred_idx)
    # Room classes are 0-based; presumably the training labels were shifted
    # down by one, so add it back (confirm against the training pipeline).
    predicted_rooms = room_pred_idx + 1  # Adjusting back to original room labels

    return predicted_area, str(predicted_rooms)

# Gradio UI: one floor-plan image in, two text fields out.
_floorplan_input = gr.Image(type="numpy", label="Upload Floor Plan Image")
_prediction_outputs = [
    gr.Textbox(label="Predicted Total Area"),
    gr.Textbox(label="Predicted Number of Rooms"),
]

interface = gr.Interface(
    fn=predict,
    inputs=_floorplan_input,
    outputs=_prediction_outputs,
    title="Floor Plan Area and Room Predictor",
    description="Upload a floor plan image, and the model will predict the total area range and the number of rooms.",
)

def main():
    """Launch the Gradio web interface for the floor-plan predictor."""
    interface.launch()


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    main()