from argparse import ArgumentParser

import torch

# from sparsification.glm_saga import glm_saga
from sparsification import feature_helpers


def safe_zip(*args):
    """Zip iterables, raising ValueError if they are not all the same length."""
    for iterable in args[1:]:
        if len(iterable) != len(args[0]):
            print("Unequally sized iterables to zip, printing lengths")
            for i, entry in enumerate(args):
                print(i, len(entry))
            raise ValueError("Unequally sized iterables to zip")
    return zip(*args)
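
# e.g. list(safe_zip([1, 2], ['a', 'b'])) -> [(1, 'a'), (2, 'b')], while
# safe_zip([1, 2], ['a']) raises ValueError after printing both lengths.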


def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats,
                                  num_classes):
    print("Computing/loading deep features...")
    Ntotal = len(train_loader.dataset)
    feature_loaders = {}
    # Compute features for the non-augmented train and test sets: temporarily
    # swap the train transform for the test transform so no augmentation is applied.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms
    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
        print(f"For {mode} set...")
        sink_path = f"{out_dir_feats}/features_{mode}"
        metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,  # args.num_workers,
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,  # args.balance if mode == 'test' else False
        )
        if mode == 'train':
            metadata = feature_helpers.calculate_metadata(feature_loader,
                                                          num_classes=num_classes,
                                                          filename=metadata_path)
            if metadata["max_reg"]["group"] == 0.0:
                # Degenerate regularization path; signal failure to the caller.
                return None, False
            # Split the train features into train/val and attach sample indices.
            split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
                                                                          Ntotal,
                                                                          val_frac=args.val_frac,
                                                                          batch_size=args.batch_size,
                                                                          num_workers=args.num_workers,
                                                                          random_seed=args.random_seed,
                                                                          shuffle=True,
                                                                          balance=False)
            feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
                                    for mi, mm in enumerate(['train', 'val'])})
        else:
            feature_loaders[mode] = feature_loader
    # Restore the original (augmented) train transform.
    train_loader.dataset.transform = train_loader_transforms
    return feature_loaders, metadata


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device  # fall back to CPU when CUDA is unavailable
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(args, train_loader, test_loader,
                                                              model, feature_folder, num_classes)
    return feature_loaders, metadata, device, args
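
# Call sketch (illustrative; `train_loader`, `test_loader`, and `model` are assumed
# to exist, and both loaders' datasets must expose a `.transform` attribute):
#
#     from pathlib import Path
#     feature_loaders, metadata, device, args = get_feature_loaders(
#         seed=0, log_folder=Path("logs/run0"), train_loader=train_loader,
#         test_loader=test_loader, model=model, num_classes=10)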


def add_index_to_dataloader(loader, sample_weight=None):
    """Wrap a DataLoader so each batch also yields the sample indices."""
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
    )


class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that appends the sample index (and optional weight) to each item."""

    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight

    def __getitem__(self, index):
        val = self.dataset[index]
        if self.sample_weight is None:
            return val + (index,)
        weight = self.sample_weight[index]
        return val + (weight, index)

    def __len__(self):
        return len(self.dataset)
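
# Minimal usage sketch (illustrative data): IndexedDataset appends each sample's
# index to the tuple returned by the wrapped dataset, which is what the loaders
# built by add_index_to_dataloader rely on downstream.
#
#     ds = torch.utils.data.TensorDataset(torch.randn(4, 3), torch.arange(4))
#     x, y, idx = IndexedDataset(ds)[2]  # idx == 2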


def get_default_args():
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str, help='one of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')
    # Parse an empty argument list so the defaults are returned even when this
    # function is called from code rather than from a command line.
    args = parser.parse_args([])
    return args
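
# Usage sketch: the returned namespace simply carries the glm_saga defaults,
# e.g. get_default_args().batch_size == 256 and get_default_args().val_frac == 0.1.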


def select_in_loader(feature_loaders, feature_selection):
    """Restrict the feature datasets behind the loaders to the selected feature columns."""
    # Val is indexed via the same underlying dataset as train, so slicing the
    # train loader's datasets updates the val loader as well.
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue  # already sliced
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders
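

# Self-contained sketch of the column selection performed by select_in_loader,
# shown on a bare TensorDataset (hypothetical data; the real loaders wrap
# nested ConcatDatasets of such TensorDatasets):
if __name__ == "__main__":
    ds = torch.utils.data.TensorDataset(torch.randn(8, 5), torch.zeros(8, dtype=torch.long))
    keep = [0, 2, 4]  # feature columns to retain
    tensors = list(ds.tensors)
    tensors[0] = tensors[0][:, keep]  # slice the feature tensor in place
    ds.tensors = tensors
    assert ds.tensors[0].shape == (8, 3)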