# Author: SciOner · commit 8601eb3 ("init commit") · 3.43 kB
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Define your model architecture
class EfficientNetMultiTask(nn.Module):
    """EfficientNet-B0 backbone shared by two classification heads.

    The backbone's own classifier is replaced by ``nn.Identity`` so it returns
    its feature vector, which is fed to two independent MLP heads:

    * ``area_classifier`` -> ``n_area_classes`` logits (floor-area bin)
    * ``room_classifier`` -> ``n_room_classes`` logits (room-count class)

    The attribute names and the layer order inside each head determine the
    ``state_dict`` keys (e.g. ``area_classifier.3.weight``), so they must stay
    in sync with any saved checkpoint.

    Args:
        n_area_classes: Number of outputs of the area head.
        n_room_classes: Number of outputs of the room head.
    """

    def __init__(self, n_area_classes, n_room_classes):
        super().__init__()
        # weights=None builds the architecture without downloading pretrained
        # weights (the torchvision >= 0.13 replacement for the deprecated
        # pretrained=False); trained parameters come from load_state_dict.
        self.efficientnet = models.efficientnet_b0(weights=None)
        # Input width of the backbone's original final Linear layer
        # (classifier[1]); both heads consume features of this size.
        in_features = self.efficientnet.classifier[1].in_features
        self.area_classifier = self._make_head(in_features, n_area_classes)
        self.room_classifier = self._make_head(in_features, n_room_classes)
        # Drop the original classifier so the backbone outputs raw features.
        self.efficientnet.classifier = nn.Identity()

    @staticmethod
    def _make_head(in_features, n_classes):
        """Build one Linear -> ReLU -> Dropout -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, n_classes),
        )

    def forward(self, x):
        """Return ``(area_logits, room_logits)`` for the image batch ``x``.

        Both outputs are raw, unnormalized scores; no softmax is applied here.
        ``x`` is presumably a normalized ``(N, 3, H, W)`` float batch such as
        ``test_transform(img).unsqueeze(0)`` -- confirm for other callers.
        """
        features = self.efficientnet(x)
        area_pred = self.area_classifier(features)
        room_pred = self.room_classifier(features)
        return area_pred, room_pred
# Load model
# Head sizes must match the checkpoint, otherwise load_state_dict (strict by
# default) fails on the final Linear layer of each head.
n_area_classes = 21  # One class per interval of area_bins: 0-25, ..., 475-500, 500+
n_room_classes = 16  # Adjust based on your dataset
model = EfficientNetMultiTask(n_area_classes=n_area_classes, n_room_classes=n_room_classes)
# floorplan_model_classification.pth must sit in the same directory as this
# script. Resolve it relative to this file rather than the current working
# directory, so the app also starts when launched from another directory.
model_weights_path = Path(__file__).resolve().parent / 'floorplan_model_classification.pth'
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# prevents the unpickler from executing arbitrary code.
model.load_state_dict(
    torch.load(model_weights_path, map_location=torch.device('cpu'), weights_only=True)
)
# Inference mode: disables the Dropout layers in both heads.
model.eval()
# Define transformations
# Inference-time preprocessing, applied in order: resize to a fixed 224x224,
# convert the PIL image to a float tensor (values scaled to [0, 1]), then
# normalize each channel with the standard ImageNet mean/std. Presumably this
# mirrors the evaluation transform used during training -- confirm.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
test_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
# Define area bins
# Bin edges in m²: 0, 25, 50, ..., 500, then +inf. The final interval
# [500, inf) is the open-ended "500+" category.
area_bins = [*range(0, 525, 25), float('inf')]


def get_area_from_bin(area_bin_idx):
    """Return a human-readable area range for an area-class index.

    Indices below ``len(area_bins) - 2`` map to ``"<lower> - <upper> m²"``,
    built from two consecutive bin edges. Every other index maps to the
    open-ended ``"<last finite edge>+ m²"`` label.

    >>> get_area_from_bin(0)
    '0 - 25 m²'
    >>> get_area_from_bin(20)
    '500+ m²'
    """
    n_closed_bins = len(area_bins) - 2
    if area_bin_idx < n_closed_bins:
        lower = area_bins[area_bin_idx]
        upper = area_bins[area_bin_idx + 1]
        return f"{lower} - {upper} m²"
    return f"{area_bins[-2]}+ m²"
# Prediction function
def predict(image):
    """Predict the total-area range and room count for one floor-plan image.

    Args:
        image: Array supplied by the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A ``(predicted_area, predicted_rooms)`` pair of strings, for example
        ``("75 - 100 m²", "3")``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a readable
            message instead of a traceback from ``Image.fromarray(None)``.
    """
    if image is None:
        raise gr.Error("Please upload a floor plan image.")
    pil_image = Image.fromarray(image).convert('RGB')
    # Add a leading batch dimension: the model is run on a batch of one.
    batch = test_transform(pil_image).unsqueeze(0)
    with torch.no_grad():
        area_logits, room_logits = model(batch)
    # Softmax is monotonic, so the argmax of the raw logits already gives the
    # predicted class; converting to probabilities first is unnecessary.
    area_pred_idx = torch.argmax(area_logits, dim=1).item()
    room_pred_idx = torch.argmax(room_logits, dim=1).item()
    predicted_area = get_area_from_bin(area_pred_idx)
    # Room classes are 0-based, and class k is reported as k + 1 rooms.
    # This assumes the labels were shifted down by one for training -- confirm.
    predicted_rooms = room_pred_idx + 1
    return predicted_area, str(predicted_rooms)
# Gradio interface
# One image input, which predict() receives as a numpy array, and two text
# outputs. The outputs are filled, in order, with the (area range, room count)
# strings that predict() returns.
_floorplan_image_input = gr.Image(type="numpy", label="Upload Floor Plan Image")
_prediction_outputs = [
    gr.Textbox(label="Predicted Total Area"),
    gr.Textbox(label="Predicted Number of Rooms"),
]
interface = gr.Interface(
    fn=predict,
    inputs=_floorplan_image_input,
    outputs=_prediction_outputs,
    title="Floor Plan Area and Room Predictor",
    description=(
        "Upload a floor plan image, and the model will predict "
        "the total area range and the number of rooms."
    ),
)
# Launch Gradio interface
# Start the web server only when this file is run directly. Importing the
# module still builds the model and the interface, but does not launch.
if __name__ == "__main__":
    interface.launch()