Upload 2 files
- readme.txt +15 -0
- test_example.py +293 -0
readme.txt
ADDED
@@ -0,0 +1,15 @@
1. Model Weight Files and Dataset Configuration

Model Weight Files: The model weight files are located in the ./configs folder. Ensure that these files are correctly loaded when running the model.
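
A minimal sketch of such a config, assuming only the keys that test_example.py actually reads (the values below are placeholders, not tested defaults):

    # configs/example.yaml -- illustrative sketch only
    image_encoder_name: 'resnet50'    # selects the image-encoder branch in get_model
    image_res: 224                    # input resolution for the test transform
    test_file: './data/test.csv'      # CSV with image paths and labels
    test_batch_size: 64
    test_num_workers: 4
    finetune: './checkpoints/base_checkpoint.pt'   # overrides --finetune when non-empty
    loss_ratio: 1
    # lam: 0.5                        # optional; forwarded to TQN_Model when present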

CSV File: The CSV file contains the paths to the test images and their corresponding labels. This file is used for testing the model.
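
Note that Chestxray14_Dataset reads the image path from column 0 and the labels from column 3 onward, so a CSV in the expected shape might look like this (header and values are illustrative only):

    image_path,meta_1,meta_2,atelectasis,cardiomegaly,...,hernia
    ./ChestXray8/images/00000001_000.png,x,y,0,1,...,0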

Modifying test_example.py:

In the Chestxray14_Dataset class within test_example.py, you need to adjust the column indices according to the structure of your CSV file. Specifically, ensure that the columns for image paths and labels are correctly referenced.
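
For example, the defaults in __init__ assume paths in column 0 and labels from column 3 onward; if your labels instead start right after the path column, a change along these lines would be needed (a sketch, not a shipped configuration):

    self.img_path_list = np.asarray(data_info.iloc[:, 0])   # image paths: column 0
    self.class_list = np.asarray(data_info.iloc[:, 3:])     # default: labels from column 3
    # self.class_list = np.asarray(data_info.iloc[:, 1:])   # if labels start at column 1 instead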

The valid_on method in test_example.py contains a text_list that should correspond to the labels in your CSV file. Make sure the text_list matches the labels in your dataset.
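
The order matters here: text_list[i] is scored against label column i, so the list must follow the column order of your CSV. The list shipped in valid_on is:

    text_list = ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration", "lung mass", "lung nodule", "pneumonia", "pneumothorax", "consolidation", "edema", "emphysema", "fibrosis", "pleural thicken", "hernia"]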

2. Model Output

Prediction Results: The model will generate prediction results for each image and each class. These results will be saved as .npy files.

Output Location: The .npy files will be stored in subfolders within the ./results directory. Each subfolder corresponds to a specific run or configuration of the model.
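
A short sketch of how the saved arrays might be scored afterwards (assuming scikit-learn is available; roc_auc_score is not part of this repository, and the output folder name is illustrative):

    import numpy as np
    from sklearn.metrics import roc_auc_score

    gt = np.load('./results/test-<timestamp>/gt.npy')      # (num_images, num_classes) ground-truth labels
    pred = np.load('./results/test-<timestamp>/pred.npy')  # same shape, sigmoid probabilities
    for i in range(gt.shape[1]):
        print(f'class {i}: AUC = {roc_auc_score(gt[:, i], pred[:, i]):.4f}')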
test_example.py
ADDED
@@ -0,0 +1,293 @@
import sys
import argparse
import os
import logging
import yaml
import numpy as np
import random
import time
import datetime
import json
import math
from pathlib import Path
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModel, BertConfig, AutoTokenizer
from torchvision import transforms
import PIL
from PIL import Image
import pandas as pd
from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense
from factory import utils


class Chestxray14_Dataset(Dataset):
    def __init__(self, csv_path, image_res):
        data_info = pd.read_csv(csv_path)
        # Column 0 holds the image paths; columns 3 onward hold the per-class
        # labels. Adjust these indices to match the layout of your CSV file.
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 3:])

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transform = transforms.Compose([
            transforms.Resize(image_res, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        # Revise this path mapping according to the actual circumstances.
        img_path = self.img_path_list[index].replace('/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/', '/remote-home/share/medical/public/ChestXray8/')

        class_label = self.class_list[index]
        img = Image.open(img_path).convert('RGB')
        image = self.transform(img)
        return {
            "img_path": img_path,
            "image": image,
            "label": class_label
        }

    def __len__(self):
        return len(self.img_path_list)


def get_text_features(model, text_list, tokenizer, device, max_length):
    text_token = tokenizer(list(text_list), add_special_tokens=True, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt").to(device=device)
    text_features = model.encode_text(text_token)
    return text_features


def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
    model.eval()
    image_encoder.eval()
    text_encoder.eval()

    # These class names must match the label columns of your CSV file, in order.
    text_list = ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration", "lung mass", "lung nodule", "pneumonia", "pneumothorax", "consolidation", "edema", "emphysema", "fibrosis", "pleural thicken", "hernia"]

    text_features = get_text_features(text_encoder, text_list, tokenizer, device, max_length=args.max_length)
    # Replicate the text features once per GPU so each DataParallel replica gets a copy.
    device_num = torch.cuda.device_count()
    text_features = text_features.repeat(int(device_num), 1)

    val_scalar_step = epoch * len(data_loader)
    val_losses = []

    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    for i, sample in enumerate(data_loader):
        image = sample['image'].to(device, non_blocking=True)
        label = sample['label'].long().to(device)
        label = label.float()
        gt = torch.cat((gt, label), 0)
        with torch.no_grad():
            image_features, image_features_pool = image_encoder(image)

            pred_class = model(image_features, text_features)  # (batch, 14, 2) or (batch, 14, 1)
            val_loss = F.binary_cross_entropy_with_logits(pred_class.view(-1, 1), label.view(-1, 1))
            pred_class = torch.sigmoid(pred_class)
            pred = torch.cat((pred, pred_class[:, :, 0]), 0)

        val_losses.append(val_loss.item())
        writer.add_scalar('val_loss/loss', val_loss, val_scalar_step)
        val_scalar_step += 1
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    np.save(f'{args.output_dir}/gt.npy', gt_np)
    np.save(f'{args.output_dir}/pred.npy', pred_np)

    return


def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
    valid_on(model, image_encoder, text_encoder, tokenizer, test_dataloader, epoch, device, args, config, writer, total_test=True)


def get_dataloader(args, config):
    test_dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config['test_batch_size'],
        num_workers=config["test_num_workers"],
        pin_memory=True,
        collate_fn=None,
        shuffle=False,
        drop_last=False,
    )
    return test_dataloader, test_dataset


def get_model(args, config):
    if 'resnet' in config['image_encoder_name']:
        image_encoder = ModelRes(config['image_encoder_name']).cuda()
        preprocess = None
    elif 'convnext' in config['image_encoder_name']:
        image_encoder = ModelConvNeXt(config['image_encoder_name']).cuda()
        preprocess = None
    elif 'efficientnet' in config['image_encoder_name']:
        image_encoder = ModelEfficientV2(config['image_encoder_name']).cuda()
        preprocess = None
    elif 'densenet' in config['image_encoder_name']:
        image_encoder = ModelDense(config['image_encoder_name']).cuda()
        preprocess = None
    else:
        raise NotImplementedError(f"Unknown image encoder: {config['image_encoder_name']}")

    tokenizer = AutoTokenizer.from_pretrained(args.bert_model_name, do_lower_case=True, local_files_only=True)
    text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()

    if args.bert_pretrained:
        checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
        state_dict = checkpoint["state_dict"]
        text_encoder.load_state_dict(state_dict, strict=False)
        if args.freeze_bert:
            for param in text_encoder.parameters():
                param.requires_grad = False

    if 'lam' in config:
        model = TQN_Model(class_num=args.class_num, lam=config['lam']).cuda()
    else:
        model = TQN_Model(class_num=args.class_num).cuda()

    if args.distributed:
        model = torch.nn.DataParallel(model)
        image_encoder = torch.nn.DataParallel(image_encoder)

    return model, image_encoder, text_encoder, tokenizer


def load_checkpoint(model, image_encoder, args):
    if os.path.isfile(args.finetune):
        checkpoint = torch.load(args.finetune, map_location='cpu')

        # Add or strip the 'module.' prefix so the checkpoint keys match the
        # (possibly DataParallel-wrapped) image encoder.
        image_state_dict = checkpoint['image_encoder']
        new_image_state_dict = OrderedDict()
        if 'module.' in list(image_encoder.state_dict().keys())[0] and 'module.' not in list(image_state_dict.keys())[0]:
            for k, v in image_state_dict.items():
                name = 'module.' + k
                new_image_state_dict[name] = v
        elif 'module.' not in list(image_encoder.state_dict().keys())[0] and 'module.' in list(image_state_dict.keys())[0]:
            for k, v in image_state_dict.items():
                name = k.replace('module.', '')
                new_image_state_dict[name] = v
        else:
            new_image_state_dict = image_state_dict
        image_encoder.load_state_dict(new_image_state_dict, strict=False)

        # Same prefix handling for the TQN model weights.
        state_dict = checkpoint['model']
        new_state_dict = OrderedDict()
        if 'module.' in list(model.state_dict().keys())[0] and 'module.' not in list(state_dict.keys())[0]:
            for k, v in state_dict.items():
                name = 'module.' + k
                new_state_dict[name] = v
        elif 'module.' not in list(model.state_dict().keys())[0] and 'module.' in list(state_dict.keys())[0]:
            for k, v in state_dict.items():
                name = k.replace('module.', '')
                new_state_dict[name] = v
        else:
            new_state_dict = state_dict

        model.load_state_dict(new_state_dict, strict=False)
        print("Checkpoint loaded successfully.")


def seed_torch(seed=42):
    # if os.environ['LOCAL_RANK'] == '0':
    print('=====> Using fixed random seed: ' + str(seed))
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(args, config):
    '''Prepare the data'''
    test_dataloader, test_dataset = get_dataloader(args, config)

    test_dataloader.num_samples = len(test_dataset)
    test_dataloader.num_batches = len(test_dataloader)

    '''Prepare the model'''
    model, image_encoder, text_encoder, tokenizer = get_model(args, config)

    writer = SummaryWriter(os.path.join(args.output_dir, 'log'))

    load_checkpoint(model, image_encoder, args)

    test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, args.device, args, config, writer, epoch=0, total_test=True)


if __name__ == '__main__':
    abs_file_path = os.path.abspath(__file__)
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    parser = argparse.ArgumentParser()
    # Note: argparse's type=bool treats any non-empty string as True, so these
    # flags are effectively fixed at their defaults unless passed an empty string.
    parser.add_argument('--momentum', default=False, type=bool)
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--finetune', default='base_checkpoint.pt')

    parser.add_argument('--freeze_bert', default=True, type=bool)
    parser.add_argument("--use_entity_features", default=True, type=bool)

    parser.add_argument('--config', default='example.yaml')

    parser.add_argument('--fourier', default=True, type=bool)
    parser.add_argument('--colourjitter', default=True, type=bool)
    parser.add_argument('--class_num', default=1, type=int)  # FT1, FF2

    parser.add_argument('--ignore_index', default=True, type=bool)  # originally False; with added data, labels of -1 are markers excluded from the loss, so this was changed to True
    parser.add_argument('--add_dataset', default=False, type=bool)

    time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
    parser.add_argument('--bert_pretrained', default='./pretrained_bert_weights/epoch_latest.pt')
    parser.add_argument('--bert_model_name', default='./pretrained_bert_weights/UMLSBert_ENG/')
    parser.add_argument('--max_length', default=256, type=int)
    parser.add_argument('--loss_ratio', default=1, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)

    # distributed training parameters
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--distributed', action='store_true', default=False, help='Use multi-processing distributed training to launch')

    parser.add_argument('--rho', default=0, type=float, help='gpu')
    parser.add_argument('--gpu', default=0, type=int, help='gpu')
    args = parser.parse_args()
    args.config = f'./configs/{args.config}'

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
    if config['finetune'] != '':
        args.finetune = config['finetune']
        args.checkpoint = config['finetune']

    args.loss_ratio = config['loss_ratio']
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    seed_torch(args.seed)

    main(args, config)
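
# Example invocation (a sketch based on this script's argparse defaults; the
# --config name is resolved relative to ./configs/):
#   python test_example.py --config example.yaml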