fixing shape
Browse files
run.sh
CHANGED
File without changes
|
test.py
CHANGED
@@ -10,6 +10,7 @@ import argparse
|
|
10 |
import json
|
11 |
import os
|
12 |
import torch
|
|
|
13 |
# from torch.utils.data import Subset
|
14 |
# from scipy.ndimage import gaussian_filter
|
15 |
# import cv2
|
@@ -17,15 +18,20 @@ import numpy as np
|
|
17 |
|
18 |
# Importing from local modules
|
19 |
from tools import write2csv, setup_seed, Logger
|
|
|
20 |
# from dataset import get_data, dataset_dict
|
21 |
from method import AdaCLIP_Trainer
|
22 |
from PIL import Image
|
23 |
import numpy as np
|
24 |
-
from datasets.rayan_dataset import RayanDataset
|
25 |
from utils.dump_scores import DumpScores
|
|
|
26 |
|
27 |
setup_seed(111)
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
def get_available_class_names(data_path):
|
31 |
all_items = os.listdir(data_path)
|
@@ -41,7 +47,7 @@ def train(args):
|
|
41 |
args.ckt_path
|
42 |
), f"Please check the path of pre-trained model, {args.ckt_path} is not valid."
|
43 |
|
44 |
-
data_path = "./data/"
|
45 |
class_names = get_available_class_names(data_path)
|
46 |
|
47 |
# Configurations
|
@@ -97,15 +103,23 @@ def train(args):
|
|
97 |
csv_path = os.path.join(csv_root, f"{args.testing_data}.csv")
|
98 |
image_dir = os.path.join(image_root, f"{args.testing_data}")
|
99 |
os.makedirs(image_dir, exist_ok=True)
|
100 |
-
|
101 |
dumper = DumpScores()
|
102 |
|
103 |
for classname in class_names:
|
104 |
test_data = RayanDataset(
|
105 |
source=data_path,
|
106 |
classname=classname,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
)
|
108 |
-
|
109 |
test_dataloader = torch.utils.data.DataLoader(
|
110 |
test_data, batch_size=1, shuffle=False
|
111 |
)
|
|
|
10 |
import json
|
11 |
import os
|
12 |
import torch
|
13 |
+
|
14 |
# from torch.utils.data import Subset
|
15 |
# from scipy.ndimage import gaussian_filter
|
16 |
# import cv2
|
|
|
18 |
|
19 |
# Importing from local modules
|
20 |
from tools import write2csv, setup_seed, Logger
|
21 |
+
|
22 |
# from dataset import get_data, dataset_dict
|
23 |
from method import AdaCLIP_Trainer
|
24 |
from PIL import Image
|
25 |
import numpy as np
|
26 |
+
from datasets.rayan_dataset import RayanDataset
|
27 |
from utils.dump_scores import DumpScores
|
28 |
+
from torchvision import transforms
|
29 |
|
30 |
setup_seed(111)
|
31 |
|
32 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
33 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
34 |
+
|
35 |
|
36 |
def get_available_class_names(data_path):
|
37 |
all_items = os.listdir(data_path)
|
|
|
47 |
args.ckt_path
|
48 |
), f"Please check the path of pre-trained model, {args.ckt_path} is not valid."
|
49 |
|
50 |
+
data_path = "./data/"
|
51 |
class_names = get_available_class_names(data_path)
|
52 |
|
53 |
# Configurations
|
|
|
103 |
csv_path = os.path.join(csv_root, f"{args.testing_data}.csv")
|
104 |
image_dir = os.path.join(image_root, f"{args.testing_data}")
|
105 |
os.makedirs(image_dir, exist_ok=True)
|
106 |
+
|
107 |
dumper = DumpScores()
|
108 |
|
109 |
for classname in class_names:
|
110 |
test_data = RayanDataset(
|
111 |
source=data_path,
|
112 |
classname=classname,
|
113 |
+
external_transform=transforms.Compose(
|
114 |
+
[
|
115 |
+
transforms.Resize((224, 224)),
|
116 |
+
transforms.CenterCrop(224),
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
119 |
+
]
|
120 |
+
),
|
121 |
)
|
122 |
+
|
123 |
test_dataloader = torch.utils.data.DataLoader(
|
124 |
test_data, batch_size=1, shuffle=False
|
125 |
)
|