Elfenreigen committed
Commit 9da3725 · verified · 1 Parent(s): caa35e0

Upload 2 files

Files changed (2)
  1. readme.txt +15 -0
  2. test_example.py +293 -0
readme.txt ADDED
@@ -0,0 +1,15 @@
+ 1. Model Weight Files and Dataset Configuration
+ Model Weight Files: The model weight files are located in the ./configs folder. Ensure that these files are correctly loaded when running the model.
+
+ CSV File: The CSV file contains the paths to the test images and their corresponding labels. This file is used for testing the model.
+
+ Modifying test_example.py:
+
+ In the Chestxray14_Dataset class within test_example.py, adjust the column indices to the structure of your CSV file. In particular, make sure the columns for the image paths and the labels are referenced correctly.
+
+ The valid_on method in test_example.py contains a text_list that should correspond to the labels in your CSV file. Make sure that text_list is updated to match the labels in your dataset.
+
+ 2. Model Output
+ Prediction Results: The model generates a prediction for each image and each class. These results are saved as .npy files.
+
+ Output Location: The .npy files are stored in subfolders within the ./results directory. Each subfolder corresponds to a specific run or configuration of the model.
test_example.py ADDED
@@ -0,0 +1,293 @@
+ import sys
+ import argparse
+ import os
+ import logging
+ import yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ import math
+ from pathlib import Path
+ from functools import partial
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from transformers import AutoModel, BertConfig, AutoTokenizer
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ import PIL
+ from PIL import Image
+ from models.clip_tqn import CLP_clinical, ModelRes, TQN_Model, ModelConvNeXt, ModelEfficientV2, ModelDense
+ import pandas as pd
+ from factory import utils
+
+ class Chestxray14_Dataset(Dataset):
+     def __init__(self, csv_path, image_res):
+         data_info = pd.read_csv(csv_path)
+         # Column 0 holds the image paths; columns 3 onward hold the per-class labels.
+         # Adjust these indices to match the structure of your CSV file (see readme.txt).
+         self.img_path_list = np.asarray(data_info.iloc[:, 0])
+         self.class_list = np.asarray(data_info.iloc[:, 3:])
+
+         normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+         self.transform = transforms.Compose([
+             transforms.Resize(image_res, interpolation=Image.BICUBIC),
+             transforms.ToTensor(),
+             normalize,
+         ])
+
+     def __getitem__(self, index):
+         # Remap the path prefix; revise according to the actual circumstances of your setup.
+         img_path = self.img_path_list[index].replace('/mnt/petrelfs/zhangxiaoman/DATA/Chestxray/ChestXray8/', '/remote-home/share/medical/public/ChestXray8/')
+
+         class_label = self.class_list[index]
+         img = Image.open(img_path).convert('RGB')
+         image = self.transform(img)
+         return {
+             "img_path": img_path,
+             "image": image,
+             "label": class_label
+         }
+
+     def __len__(self):
+         return len(self.img_path_list)
+
+
+ def get_text_features(model, text_list, tokenizer, device, max_length):
+     text_token = tokenizer(list(text_list), add_special_tokens=True, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt").to(device=device)
+     text_features = model.encode_text(text_token)
+     return text_features
+
+ def valid_on(model, image_encoder, text_encoder, tokenizer, data_loader, epoch, device, args, config, writer, total_test=False):
+     model.eval()
+     image_encoder.eval()
+     text_encoder.eval()
+
+     # Class prompts; keep this list consistent with the label columns of the test CSV.
+     text_list = ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration", "lung mass", "lung nodule", "pneumonia", "pneumothorax", "consolidation", "edema", "emphysema", "fibrosis", "pleural thicken", "hernia"]
+
+     text_features = get_text_features(text_encoder, text_list, tokenizer, device, max_length=args.max_length)
+     device_num = torch.cuda.device_count()
+     text_features = text_features.repeat(int(device_num), 1)
+
+     val_scalar_step = epoch * len(data_loader)
+     val_losses = []
+
+     gt = torch.FloatTensor().cuda()
+     pred = torch.FloatTensor().cuda()
+
+     for i, sample in enumerate(data_loader):
+         image = sample['image'].to(device, non_blocking=True)
+         label = sample['label'].long().to(device)
+         label = label.float()
+         gt = torch.cat((gt, label), 0)
+         with torch.no_grad():
+             image_features, image_features_pool = image_encoder(image)
+
+             pred_class = model(image_features, text_features)  # shape: (batch, num_classes, 2) or (batch, num_classes, 1)
+             val_loss = F.binary_cross_entropy_with_logits(pred_class.view(-1, 1), label.view(-1, 1))
+             pred_class = torch.sigmoid(pred_class)
+             pred = torch.cat((pred, pred_class[:, :, 0]), 0)
+
+         val_losses.append(val_loss.item())
+         writer.add_scalar('val_loss/loss', val_loss, val_scalar_step)
+         val_scalar_step += 1
+
+     gt_np = gt.cpu().numpy()
+     pred_np = pred.cpu().numpy()
+     np.save(f'{args.output_dir}/gt.npy', gt_np)
+     np.save(f'{args.output_dir}/pred.npy', pred_np)
+
+     return
+
+
+ def test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, device, args, config, writer, epoch=0, total_test=True):
+     valid_on(model, image_encoder, text_encoder, tokenizer, test_dataloader, epoch, device, args, config, writer, total_test=True)
+
+
+ def get_dataloader(args, config):
+     test_dataset = Chestxray14_Dataset(config['test_file'], config['image_res'])
+     test_dataloader = DataLoader(
+         test_dataset,
+         batch_size=config['test_batch_size'],
+         num_workers=config["test_num_workers"],
+         pin_memory=True,
+         collate_fn=None,
+         shuffle=False,
+         drop_last=False,
+     )
+
+     return test_dataloader, test_dataset
+
+
+ def get_model(args, config):
+     if 'resnet' in config['image_encoder_name']:
+         image_encoder = ModelRes(config['image_encoder_name']).cuda()
+         preprocess = None
+     elif 'convnext' in config['image_encoder_name']:
+         image_encoder = ModelConvNeXt(config['image_encoder_name']).cuda()
+         preprocess = None
+     elif 'efficientnet' in config['image_encoder_name']:
+         image_encoder = ModelEfficientV2(config['image_encoder_name']).cuda()
+         preprocess = None
+     elif 'densenet' in config['image_encoder_name']:
+         image_encoder = ModelDense(config['image_encoder_name']).cuda()
+         preprocess = None
+     else:
+         raise NotImplementedError(f"Unknown image encoder: {config['image_encoder_name']}")
+
+     tokenizer = AutoTokenizer.from_pretrained(args.bert_model_name, do_lower_case=True, local_files_only=True)
+     text_encoder = CLP_clinical(bert_model_name=args.bert_model_name).cuda()
+
+     if args.bert_pretrained:
+         checkpoint = torch.load(args.bert_pretrained, map_location='cpu')
+         state_dict = checkpoint["state_dict"]
+         text_encoder.load_state_dict(state_dict, strict=False)
+         if args.freeze_bert:
+             for param in text_encoder.parameters():
+                 param.requires_grad = False
+
+     if 'lam' in config:
+         model = TQN_Model(class_num=args.class_num, lam=config['lam']).cuda()
+     else:
+         model = TQN_Model(class_num=args.class_num).cuda()
+
+     if args.distributed:
+         model = torch.nn.DataParallel(model)
+         image_encoder = torch.nn.DataParallel(image_encoder)
+
+     return model, image_encoder, text_encoder, tokenizer
+
+ def load_checkpoint(model, image_encoder, args):
+     if os.path.isfile(args.finetune):
+         checkpoint = torch.load(args.finetune, map_location='cpu')
+         # Reconcile the 'module.' prefix that DataParallel adds to parameter names.
+         image_state_dict = checkpoint['image_encoder']
+         new_image_state_dict = OrderedDict()
+         if 'module.' in list(image_encoder.state_dict().keys())[0] and 'module.' not in list(image_state_dict.keys())[0]:
+             for k, v in image_state_dict.items():
+                 name = 'module.' + k
+                 new_image_state_dict[name] = v
+         elif 'module.' not in list(image_encoder.state_dict().keys())[0] and 'module.' in list(image_state_dict.keys())[0]:
+             for k, v in image_state_dict.items():
+                 name = k.replace('module.', '')
+                 new_image_state_dict[name] = v
+         else:
+             new_image_state_dict = image_state_dict
+         image_encoder.load_state_dict(new_image_state_dict, strict=False)
+
+         state_dict = checkpoint['model']
+         new_state_dict = OrderedDict()
+         if 'module.' in list(model.state_dict().keys())[0] and 'module.' not in list(state_dict.keys())[0]:
+             for k, v in state_dict.items():
+                 name = 'module.' + k
+                 new_state_dict[name] = v
+         elif 'module.' not in list(model.state_dict().keys())[0] and 'module.' in list(state_dict.keys())[0]:
+             for k, v in state_dict.items():
+                 name = k.replace('module.', '')
+                 new_state_dict[name] = v
+         else:
+             new_state_dict = state_dict
+
+         model.load_state_dict(new_state_dict, strict=False)
+         print("Checkpoint loaded successfully.")
+
+ def seed_torch(seed=42):
+     # if os.environ['LOCAL_RANK'] == '0':
+     print('=====> Using fixed random seed: ' + str(seed))
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def main(args, config):
+     '''Prepare the data'''
+     test_dataloader, test_dataset = get_dataloader(args, config)
+
+     test_dataloader.num_samples = len(test_dataset)
+     test_dataloader.num_batches = len(test_dataloader)
+
+     '''Prepare the model'''
+     model, image_encoder, text_encoder, tokenizer = get_model(args, config)
+
+     writer = SummaryWriter(os.path.join(args.output_dir, 'log'))
+
+     load_checkpoint(model, image_encoder, args)
+
+     test_all(model, image_encoder, text_encoder, tokenizer, test_dataloader, args.device, args, config, writer, epoch=0, total_test=True)
+
+
+ if __name__ == '__main__':
+     abs_file_path = os.path.abspath(__file__)
+     os.environ['OMP_NUM_THREADS'] = '1'
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--momentum', default=False, type=bool)
+     parser.add_argument('--checkpoint', default='')
+     parser.add_argument('--finetune', default='base_checkpoint.pt')
+
+     parser.add_argument('--freeze_bert', default=True, type=bool)
+     parser.add_argument("--use_entity_features", default=True, type=bool)
+
+     parser.add_argument('--config', default='example.yaml')
+
+     parser.add_argument('--fourier', default=True, type=bool)
+     parser.add_argument('--colourjitter', default=True, type=bool)
+     parser.add_argument('--class_num', default=1, type=int)  # FT1, FF2
+
+     parser.add_argument('--ignore_index', default=True, type=bool)  # originally False; when extra data is added, labels marked -1 are excluded from the loss, so set this to True
+     parser.add_argument('--add_dataset', default=False, type=bool)
+
+     time_now = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
+     parser.add_argument('--output_dir', default=f'./results/test-{time_now}')
+     parser.add_argument('--aws_output_dir', default=f'./results/test-{time_now}')
+     parser.add_argument('--bert_pretrained', default='./pretrained_bert_weights/epoch_latest.pt')
+     parser.add_argument('--bert_model_name', default='./pretrained_bert_weights/UMLSBert_ENG/')
+     parser.add_argument('--max_length', default=256, type=int)
+     parser.add_argument('--loss_ratio', default=1, type=int)
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+
+     # distributed training parameters
+     parser.add_argument("--local_rank", type=int)
+     parser.add_argument('--distributed', action='store_true', default=False, help='Use multi-processing distributed training to launch')
+
+     parser.add_argument('--rho', default=0, type=float, help='gpu')
+     parser.add_argument('--gpu', default=0, type=int, help='gpu')
+     args = parser.parse_args()
+     args.config = f'./configs/{args.config}'
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+     if config['finetune'] != '':
+         args.finetune = config['finetune']
+         args.checkpoint = config['finetune']
+
+     args.loss_ratio = config['loss_ratio']
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     Path(args.aws_output_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     seed_torch(args.seed)
+
+     main(args, config)
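Following the Model Output section of readme.txt, valid_on saves gt.npy and pred.npy into args.output_dir (a timestamped subfolder of ./results). Below is a minimal sketch for loading them back and computing per-class AUROC; the run folder name is a placeholder, and scikit-learn is an extra dependency assumed here, not something test_example.py itself imports.

import numpy as np
from sklearn.metrics import roc_auc_score

output_dir = './results/test-2024-01-01-00-00'   # placeholder: use your actual run folder
gt = np.load(f'{output_dir}/gt.npy')             # shape (num_images, num_classes), 0/1 ground truth
pred = np.load(f'{output_dir}/pred.npy')         # shape (num_images, num_classes), sigmoid scores

# Class order follows text_list in valid_on.
text_list = ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration", "lung mass",
             "lung nodule", "pneumonia", "pneumothorax", "consolidation", "edema",
             "emphysema", "fibrosis", "pleural thicken", "hernia"]
for i, name in enumerate(text_list):
    print(f'{name}: AUROC = {roc_auc_score(gt[:, i], pred[:, i]):.4f}')
print(f'mean AUROC = {roc_auc_score(gt, pred, average="macro"):.4f}')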