Spaces:
Running
Running
Updated classifier
Browse files- app.py +3 -3
- models/flake_monolayer_classifier.pth +3 -0
app.py
CHANGED
|
@@ -13,7 +13,7 @@ from huggingface_hub import hf_hub_download
|
|
| 13 |
|
| 14 |
|
| 15 |
class FlakeLayerClassifier(nn.Module):
|
| 16 |
-
def __init__(self, num_materials, material_dim, num_classes=
|
| 17 |
super().__init__()
|
| 18 |
self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
|
| 19 |
if freeze_cnn:
|
|
@@ -72,8 +72,8 @@ print(f"Using device: {device}")
|
|
| 72 |
# Load YOLO detector
|
| 73 |
yolo = YOLO("models/uark_detector_v3.pt")
|
| 74 |
|
| 75 |
-
# Load classifier model checkpoint
|
| 76 |
-
ckpt_path = "models/
|
| 77 |
ckpt = torch.load(ckpt_path, map_location=device)
|
| 78 |
|
| 79 |
num_classes = len(ckpt["class_to_idx"])
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class FlakeLayerClassifier(nn.Module):
|
| 16 |
+
def __init__(self, num_materials, material_dim, num_classes=2, dropout_prob=0.1, freeze_cnn=False):
|
| 17 |
super().__init__()
|
| 18 |
self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
|
| 19 |
if freeze_cnn:
|
|
|
|
| 72 |
# Load YOLO detector
|
| 73 |
yolo = YOLO("models/uark_detector_v3.pt")
|
| 74 |
|
| 75 |
+
# Load classifier model checkpoint
|
| 76 |
+
ckpt_path = "models/flake_monolayer_classifier.pth"
|
| 77 |
ckpt = torch.load(ckpt_path, map_location=device)
|
| 78 |
|
| 79 |
num_classes = len(ckpt["class_to_idx"])
|
models/flake_monolayer_classifier.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc4e98bf4dd3127970ca7c68b633d29aa523a0a66a2d6481bf346d7662dbe7b8
|
| 3 |
+
size 47191055
|