import argparse
import json
import os
import sys
from collections import defaultdict

from sklearn.metrics import log_loss
from torch import topk

sys.path.append('..')
from training import losses
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
from training.losses import WeightedLosses
from training.tools.config import load_config
from training.tools.utils import create_optimizer, AverageMeter
from training.transforms.albu import IsotropicResize
from training.zoo import classifiers

# Limit intra-op threading so DataLoader workers do not oversubscribe CPUs.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

import numpy as np
from albumentations import Compose, RandomBrightnessContrast, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
from apex import amp
from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter

import torch
import torch.distributed as dist
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

torch.backends.cudnn.benchmark = True


def create_train_transforms(size=300):
    """Training augmentations: compression, noise, blur, flips, isotropic
    resize to `size`, padding, color jitter, and mild affine transforms."""
    return Compose([
        ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        GaussNoise(p=0.1),
        GaussianBlur(blur_limit=3, p=0.05),
        HorizontalFlip(),
        OneOf([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
        ], p=1),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
        OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
        ToGray(p=0.2),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10,
                         border_mode=cv2.BORDER_CONSTANT, p=0.5),
    ])


def create_val_transforms(size=300):
    """Deterministic validation preprocessing: isotropic resize plus padding."""
    return Compose([
        IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
    ])
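
# Illustrative sketch (not executed by the pipeline): how these albumentations
# pipelines are applied to a single BGR crop. The crop path and size below are
# hypothetical.
#
#   transforms = create_train_transforms(size=380)
#   image = cv2.imread("crops/some_video/0.png")
#   augmented = transforms(image=image)["image"]   # albumentations dict API
#   # after IsotropicResize + PadIfNeeded the result is 380x380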

def main():
    parser = argparse.ArgumentParser("PyTorch DeepFake Classifier Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='list of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)

    # With OHEM the per-sample losses must stay unreduced so the hardest
    # examples can be selected later; otherwise reduce to the mean directly.
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers,
                                 shuffle=False, pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix)
                                   + conf['encoder'] + "_" + str(args.fold))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # Checkpoints are saved from a (Distributed)DataParallel wrapper,
            # so strip the leading "module." prefix from every key.
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
                  .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch
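    # apex "O1" mixed precision patches most ops to run in fp16 while keeping
    # fp32 master weights; 'dynamic' loss scaling backs the scale off
    # automatically when gradients overflow.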
    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()

    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)
        # Optionally keep the encoder frozen for the first few epochs so only
        # the freshly initialized head is trained.
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler,
                                       pin_memory=False, drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader,
                    summary_writer, conf, args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1


def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
    print("Test phase")
    model = model.eval()

    bce, probs, targets = validate(model, data_loader=data_val)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
        if bce < bce_best:
            print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce,
                }, os.path.join(args.output_dir, snapshot_name + "_best_dice"))
            bce_best = bce
            with open("predictions_{}.json".format(args.fold), "w") as f:
                json.dump({"probs": probs, "targets": targets}, f)
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'bce_best': bce_best,
        }, os.path.join(args.output_dir, snapshot_name + "_last"))
        print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
    return bce_best


def validate(net, data_loader, prefix=""):
    probs = defaultdict(list)
    targets = defaultdict(list)
    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda()
            img_names = sample["img_name"]
            labels = sample["labels"].cuda().float()
            out = net(imgs)
            labels = labels.cpu().numpy()
            preds = torch.sigmoid(out).cpu().numpy()
            for i in range(out.shape[0]):
                video, img_id = img_names[i].split("/")
                probs[video].append(preds[i].tolist())
                targets[video].append(labels[i].tolist())
    # Aggregate frame-level predictions into one mean probability per video,
    # then score fake and real videos separately so the class balance of the
    # validation set does not skew the metric.
    data_x = []
    data_y = []
    for vid, score in probs.items():
        score = np.array(score)
        lbl = targets[vid]
        score = np.mean(score)
        lbl = np.mean(lbl)
        data_x.append(score)
        data_y.append(lbl)
    y = np.array(data_y)
    x = np.array(data_x)
    fake_idx = y > 0.1
    real_idx = y < 0.1
    fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
    real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
    print("{}fake_loss".format(prefix), fake_loss)
    print("{}real_loss".format(prefix), real_loss)
    return (fake_loss + real_loss) / 2, probs, targets
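
# Illustrative sketch of the balanced video-level metric computed above, with
# hypothetical numbers:
#
#   >>> y = np.array([1.0, 1.0, 0.0])   # per-video labels (2 fake, 1 real)
#   >>> x = np.array([0.9, 0.7, 0.2])   # per-video mean probabilities
#   >>> fake = log_loss(y[y > 0.1], x[y > 0.1], labels=[0, 1])
#   >>> real = log_loss(y[y < 0.1], x[y < 0.1], labels=[0, 1])
#   >>> (fake + real) / 2               # ~0.23 here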

def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer,
                conf, local_rank, only_valid):
    losses = AverageMeter()
    fake_losses = AverageMeter()
    real_losses = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        # Compute the loss separately for fake and real samples and average
        # the two, so neither class dominates a batch.
        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5

        ohem = conf.get("ohem_samples", None)
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
        if ohem:
            # Online hard example mining: keep only the k largest per-sample
            # losses in each class. Guard against batches where one class is
            # absent and the loss is still the integer 0.
            if torch.is_tensor(fake_loss):
                fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
            if torch.is_tensor(real_loss):
                real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        losses.update(loss.item(), imgs.size(0))
        fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
        real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
        optimizer.zero_grad()
        pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
                          "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})

        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)


if __name__ == '__main__':
    main()
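
# Example invocation (hypothetical config name; the sketch below only lists
# keys this script actually reads, with placeholder values):
#
#   python train_classifier.py --config configs/b7.json \
#       --data-dir /mnt/sota/datasets/deepfake --folds-csv folds.csv \
#       --crops-dir crops --gpu 0 --fold 0
#
#   {
#     "network": "DeepFakeClassifier",
#     "encoder": "tf_efficientnet_b7_ns",
#     "size": 380,
#     "fp16": true,
#     "losses": {"BinaryCrossentropy": 1},
#     "batches_per_epoch": 2500,
#     "normalize": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
#     "optimizer": {"batch_size": 12, "schedule": {"mode": "poly", "epochs": 40}}
#   }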