from argparse import ArgumentParser

import torch

# from sparsification.glm_saga import glm_saga
from sparsification import feature_helpers


def safe_zip(*args):
    """Like zip(), but raise if the input iterables differ in length."""
for iterable in args[1:]:
if len(iterable) != len(args[0]):
print("Unequally sized iterables to zip, printing lengths")
for i, entry in enumerate(args):
print(i, len(entry))
raise ValueError("Unequally sized iterables to zip")
return zip(*args)
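
# A minimal usage sketch for safe_zip (the values below are purely illustrative):
#
#   list(safe_zip([1, 2, 3], ['a', 'b', 'c']))  # -> [(1, 'a'), (2, 'b'), (3, 'c')]
#   safe_zip([1, 2, 3], ['a', 'b'])             # prints the lengths, then raises ValueError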


def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats,
                                  num_classes):
    """Compute (or load cached) deep features for the train/test sets plus train metadata."""
    print("Computing/loading deep features...")
Ntotal = len(train_loader.dataset)
feature_loaders = {}
    # Compute features for the un-augmented train set and the test set:
    # temporarily swap in the deterministic test-time transforms for the train set.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms
    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
print(f"For {mode} set...")
sink_path = f"{out_dir_feats}/features_{mode}"
metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,  # args.num_workers
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,  # args.balance if mode == 'test' else False
        )
if mode == 'train':
metadata = feature_helpers.calculate_metadata(feature_loader,
num_classes=num_classes,
filename=metadata_path)
            if metadata["max_reg"]["group"] == 0.0:
                # Degenerate regularization path: signal failure to the caller.
                return None, False
split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
Ntotal,
val_frac=args.val_frac,
batch_size=args.batch_size,
num_workers=args.num_workers,
random_seed=args.random_seed,
shuffle=True,
balance=False)
feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
for mi, mm in enumerate(['train', 'val'])})
else:
feature_loaders[mode] = feature_loader
    # Restore the original (augmented) train-set transforms.
    train_loader.dataset.transform = train_loader_transforms
return feature_loaders, metadata
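
# Sketch of the return values, assuming feature_helpers behaves as used above:
#   feature_loaders -- {'train': ..., 'val': ..., 'test': ...}; train/val go through
#     add_index_to_dataloader, so their batches carry a trailing index tensor, while
#     'test' yields plain feature batches.
#   metadata -- the dict calculated from the train features, e.g.
#     metadata["max_reg"]["group"].
# On a degenerate regularization path the function returns (None, False) instead.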


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device  # keep args.device in sync with the device actually available
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(
        args, train_loader, test_loader, model, feature_folder, num_classes)
    return feature_loaders, metadata, device, args
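
# Hypothetical call site (train_dl, test_dl and my_model are placeholders,
# not names defined in this module):
#
#   from pathlib import Path
#   feature_loaders, metadata, device, args = get_feature_loaders(
#       seed=0, log_folder=Path('logs/run0'), train_loader=train_dl,
#       test_loader=test_dl, model=my_model, num_classes=10)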


def add_index_to_dataloader(loader, sample_weight=None):
    """Rebuild a DataLoader so its batches also yield each sample's index."""
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
batch_size=loader.batch_size,
sampler=loader.sampler,
num_workers=loader.num_workers,
collate_fn=loader.collate_fn,
pin_memory=loader.pin_memory,
drop_last=loader.drop_last,
timeout=loader.timeout,
worker_init_fn=loader.worker_init_fn,
multiprocessing_context=loader.multiprocessing_context
)
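
# Sketch: wrapping an existing loader so each batch gains the sample indices
# (plain_loader is a placeholder for any DataLoader over (features, label) pairs):
#
#   indexed_loader = add_index_to_dataloader(plain_loader)
#   for features, labels, idx in indexed_loader:
#       ...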


class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that appends the sample index (and optional weight) to each item."""

    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight
def __getitem__(self, index):
val = self.dataset[index]
if self.sample_weight is None:
return val + (index,)
else:
weight = self.sample_weight[index]
return val + (weight, index)
def __len__(self):
return len(self.dataset)
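
# A small self-contained sketch of what IndexedDataset returns per item:
#
#   base = torch.utils.data.TensorDataset(torch.randn(4, 2), torch.arange(4))
#   ds = IndexedDataset(base)
#   ds[1]  # -> (features, label, 1): the sample index is appended to the tuple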


def get_default_args():
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
parser = ArgumentParser()
parser.add_argument('--dataset', type=str, help='dataset name')
parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
parser.add_argument('--dataset-path', type=str, help='path to dataset')
parser.add_argument('--model-path', type=str, help='path to model checkpoint')
parser.add_argument('--arch', type=str, help='model architecture type')
parser.add_argument('--out-path', help='location for saving results')
parser.add_argument('--cache', action='store_true', help='cache deep features')
parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', type=int, default=0)
parser.add_argument('--num-workers', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--val-frac', type=float, default=0.1)
parser.add_argument('--lr-decay-factor', type=float, default=1)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--alpha', type=float, default=0.99)
parser.add_argument('--max-epochs', type=int, default=2000)
parser.add_argument('--verbose', type=int, default=200)
parser.add_argument('--tol', type=float, default=1e-4)
parser.add_argument('--lookbehind', type=int, default=3)
parser.add_argument('--lam-factor', type=float, default=0.001)
parser.add_argument('--group', action='store_true')
    # Parse an empty argument list so the defaults are returned regardless of sys.argv.
    args = parser.parse_args([])
    return args
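
# Example: fields on the returned Namespace can be overridden before use,
# as get_feature_loaders does with random_seed and device:
#
#   args = get_default_args()
#   args.batch_size = 128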


def select_in_loader(feature_loaders, feature_selection):
    """Restrict all splits to the given feature columns, in place."""
    # The train/val loaders wrap the features as IndexedDataset -> split -> concatenated
    # tensor datasets; val is indexed via the same underlying dataset as train, so one
    # pass over the train chain updates both.
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue  # already restricted to the selected columns
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    # The test loader wraps the concatenated tensor datasets directly.
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders
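
# Hypothetical usage: restrict every split to a subset of feature columns
# (the index list is illustrative, and feature_loaders comes from
# get_feature_loaders above):
#
#   keep = torch.tensor([0, 5, 17])  # indices of the feature columns to retain
#   feature_loaders = select_in_loader(feature_loaders, keep)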