SciOner commited on
Commit
8601eb3
·
1 Parent(s): c032d2b

init commit

Browse files
Files changed (3) hide show
  1. app.py +93 -0
  2. floorplan_model_classification.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import torchvision.transforms as transforms
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
+ # Define your model architecture
10
+ class EfficientNetMultiTask(nn.Module):
11
+ def __init__(self, n_area_classes, n_room_classes):
12
+ super(EfficientNetMultiTask, self).__init__()
13
+ self.efficientnet = models.efficientnet_b0(pretrained=False)
14
+ in_features = self.efficientnet.classifier[1].in_features
15
+ self.area_classifier = nn.Sequential(
16
+ nn.Linear(in_features, 512),
17
+ nn.ReLU(),
18
+ nn.Dropout(0.3),
19
+ nn.Linear(512, n_area_classes)
20
+ )
21
+ self.room_classifier = nn.Sequential(
22
+ nn.Linear(in_features, 512),
23
+ nn.ReLU(),
24
+ nn.Dropout(0.3),
25
+ nn.Linear(512, n_room_classes)
26
+ )
27
+ self.efficientnet.classifier = nn.Identity()
28
+
29
+ def forward(self, x):
30
+ features = self.efficientnet(x)
31
+ area_pred = self.area_classifier(features)
32
+ room_pred = self.room_classifier(features)
33
+ return area_pred, room_pred
34
+
35
+
36
+ # Load model
37
+ n_area_classes = 21 # Adjust according to your area bins
38
+ n_room_classes = 16 # Adjust based on your dataset
39
+ model = EfficientNetMultiTask(n_area_classes=n_area_classes, n_room_classes=n_room_classes)
40
+
41
+ # Load weights (ensure floorplan_model_classification.pth is in the same directory as app.py)
42
+ model_weights_path = 'floorplan_model_classification.pth' # Adjust with your model weights path
43
+ model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
44
+ model.eval()
45
+
46
+ # Define transformations
47
+ test_transform = transforms.Compose([
48
+ transforms.Resize((224, 224)),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
51
+ ])
52
+
53
+ # Define area bins
54
+ area_bins = [i for i in range(0, 525, 25)] # [0, 25, ..., 500]
55
+ area_bins.append(float('inf')) # Add infinity for 500+ category
56
+
57
+ def get_area_from_bin(area_bin_idx):
58
+ if area_bin_idx < len(area_bins) - 2:
59
+ return f"{area_bins[area_bin_idx]} - {area_bins[area_bin_idx + 1]} m²"
60
+ else:
61
+ return f"{area_bins[-2]}+ m²"
62
+
63
+ # Prediction function
64
+ def predict(image):
65
+ image = Image.fromarray(image).convert('RGB')
66
+ image = test_transform(image).unsqueeze(0)
67
+
68
+ with torch.no_grad():
69
+ area_output, room_output = model(image)
70
+ area_probabilities = F.softmax(area_output, dim=1)
71
+ room_probabilities = F.softmax(room_output, dim=1)
72
+ area_pred_idx = torch.argmax(area_probabilities, dim=1).item()
73
+ room_pred_idx = torch.argmax(room_probabilities, dim=1).item()
74
+ predicted_area = get_area_from_bin(area_pred_idx)
75
+ predicted_rooms = room_pred_idx + 1 # Adjusting back to original room labels
76
+
77
+ return predicted_area, str(predicted_rooms)
78
+
79
+ # Gradio interface
80
+ interface = gr.Interface(
81
+ fn=predict,
82
+ inputs=gr.Image(type="numpy", label="Upload Floor Plan Image"), # Correct input
83
+ outputs=[
84
+ gr.Textbox(label="Predicted Total Area"), # Correct output
85
+ gr.Textbox(label="Predicted Number of Rooms")
86
+ ],
87
+ title="Floor Plan Area and Room Predictor",
88
+ description="Upload a floor plan image, and the model will predict the total area range and the number of rooms."
89
+ )
90
+
91
+ # Launch Gradio interface
92
+ if __name__ == "__main__":
93
+ interface.launch()
floorplan_model_classification.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a37c5986f0e9b486c8b2db8032630c581ee1b444e8cd80f7ace1f83aef6f1bb
3
+ size 21671252
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ gradio