smjfas commited on
Commit
b6717f0
·
1 Parent(s): ed76126

fixing shape

Browse files
Files changed (2) hide show
  1. run.sh +0 -0
  2. test.py +18 -4
run.sh CHANGED
File without changes
test.py CHANGED
@@ -10,6 +10,7 @@ import argparse
10
  import json
11
  import os
12
  import torch
 
13
  # from torch.utils.data import Subset
14
  # from scipy.ndimage import gaussian_filter
15
  # import cv2
@@ -17,15 +18,20 @@ import numpy as np
17
 
18
  # Importing from local modules
19
  from tools import write2csv, setup_seed, Logger
 
20
  # from dataset import get_data, dataset_dict
21
  from method import AdaCLIP_Trainer
22
  from PIL import Image
23
  import numpy as np
24
- from datasets.rayan_dataset import RayanDataset
25
  from utils.dump_scores import DumpScores
 
26
 
27
  setup_seed(111)
28
 
 
 
 
29
 
30
  def get_available_class_names(data_path):
31
  all_items = os.listdir(data_path)
@@ -41,7 +47,7 @@ def train(args):
41
  args.ckt_path
42
  ), f"Please check the path of pre-trained model, {args.ckt_path} is not valid."
43
 
44
- data_path = "./data/"
45
  class_names = get_available_class_names(data_path)
46
 
47
  # Configurations
@@ -97,15 +103,23 @@ def train(args):
97
  csv_path = os.path.join(csv_root, f"{args.testing_data}.csv")
98
  image_dir = os.path.join(image_root, f"{args.testing_data}")
99
  os.makedirs(image_dir, exist_ok=True)
100
-
101
  dumper = DumpScores()
102
 
103
  for classname in class_names:
104
  test_data = RayanDataset(
105
  source=data_path,
106
  classname=classname,
 
 
 
 
 
 
 
 
107
  )
108
-
109
  test_dataloader = torch.utils.data.DataLoader(
110
  test_data, batch_size=1, shuffle=False
111
  )
 
10
  import json
11
  import os
12
  import torch
13
+
14
  # from torch.utils.data import Subset
15
  # from scipy.ndimage import gaussian_filter
16
  # import cv2
 
18
 
19
  # Importing from local modules
20
  from tools import write2csv, setup_seed, Logger
21
+
22
  # from dataset import get_data, dataset_dict
23
  from method import AdaCLIP_Trainer
24
  from PIL import Image
25
  import numpy as np
26
+ from datasets.rayan_dataset import RayanDataset
27
  from utils.dump_scores import DumpScores
28
+ from torchvision import transforms
29
 
30
  setup_seed(111)
31
 
32
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
33
+ IMAGENET_STD = [0.229, 0.224, 0.225]
34
+
35
 
36
  def get_available_class_names(data_path):
37
  all_items = os.listdir(data_path)
 
47
  args.ckt_path
48
  ), f"Please check the path of pre-trained model, {args.ckt_path} is not valid."
49
 
50
+ data_path = "./data/"
51
  class_names = get_available_class_names(data_path)
52
 
53
  # Configurations
 
103
  csv_path = os.path.join(csv_root, f"{args.testing_data}.csv")
104
  image_dir = os.path.join(image_root, f"{args.testing_data}")
105
  os.makedirs(image_dir, exist_ok=True)
106
+
107
  dumper = DumpScores()
108
 
109
  for classname in class_names:
110
  test_data = RayanDataset(
111
  source=data_path,
112
  classname=classname,
113
+ external_transform=transforms.Compose(
114
+ [
115
+ transforms.Resize((224, 224)),
116
+ transforms.CenterCrop(224),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
119
+ ]
120
+ ),
121
  )
122
+
123
  test_dataloader = torch.utils.data.DataLoader(
124
  test_data, batch_size=1, shuffle=False
125
  )