# Training script for the DeepFake video classifier: per-fold training with
# optional distributed/apex-fp16 support and per-video validation log-loss.
# Thread-limiting env vars must be set BEFORE numpy/sklearn load their BLAS
# backends, otherwise they have no effect.
import os

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import argparse
import json
import sys
from collections import defaultdict

from sklearn.metrics import log_loss
from torch import topk

sys.path.append('..')  # make the `training` package importable when run from its subdirectory

from training import losses
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
from training.losses import WeightedLosses
from training.tools.config import load_config
from training.tools.utils import create_optimizer, AverageMeter
from training.transforms.albu import IsotropicResize
from training.zoo import classifiers

import cv2

# Disable OpenCL and OpenCV's internal threading: they conflict with
# multi-process DataLoader workers.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

import numpy as np
from albumentations import Compose, RandomBrightnessContrast, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter
from apex import amp
import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist

torch.backends.cudnn.benchmark = True
def create_train_transforms(size=300):
    """Build the training augmentation pipeline.

    Combines compression/noise/blur corruptions, flips, color jitter and
    geometric jitter; every image is isotropically resized and padded to a
    ``size`` x ``size`` square.

    :param size: target side length in pixels of the square output image.
    :return: an albumentations ``Compose`` object.
    """
    return Compose([
        ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        GaussNoise(p=0.1),
        GaussianBlur(blur_limit=3, p=0.05),
        HorizontalFlip(),
        # Always applied (p=1): one of three interpolation combinations so the
        # model sees several resampling "fingerprints".
        OneOf([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
        ], p=1),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
        OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
        ToGray(p=0.2),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
    ])
def create_val_transforms(size=300):
    """Build the deterministic validation pipeline: isotropic resize to
    ``size`` then constant-border padding to a ``size`` x ``size`` square.

    :param size: target side length in pixels of the square output image.
    :return: an albumentations ``Compose`` object.
    """
    return Compose([
        IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
    ])
def main():
    """Entry point: parse CLI args, build model/datasets/losses, then run the
    train / periodic-eval loop, checkpointing after every epoch."""
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    # OHEM needs per-sample losses, so reduction must stay "none" when enabled.
    ohem = conf.get("ohem_samples", None)
    reduction = "none" if ohem else "mean"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100  # sentinel "worst" score; any real validation bce improves it
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(
        os.path.join(args.logdir, conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold)))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            # Strip the "module." prefix left by (Distributed)DataParallel.
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
            if not args.zero_score:
                # BUG FIX: default to the current sentinel (100), not 0 — a 0
                # default would make every later epoch look like a regression
                # and block best-checkpoint saving.
                bce_best = checkpoint.get('bce_best', bce_best)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
                  .format(args.resume, checkpoint['epoch'], checkpoint.get('bce_best')))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')
    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)
        # Optionally freeze the encoder for the first few epochs so only the
        # classifier head trains.
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                       drop_last=True)
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer,
                    conf, args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            # BUG FIX: use os.path.join — the original mixed `dir + '/' + name`
            # and `dir + name`, which silently wrote "weightsname..." files
            # when --output-dir lacked a trailing slash.
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_last"))
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, os.path.join(args.output_dir, snapshot_name + "_{}".format(current_epoch)))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
    """Validate the model and checkpoint when the per-video bce improves.

    :param data_val: validation DataLoader.
    :param bce_best: best (lowest) validation score seen so far.
    :return: the (possibly updated) best validation score.
    """
    print("Test phase")
    model = model.eval()

    bce, probs, targets = validate(model, data_loader=data_val)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
        if bce < bce_best:
            print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
            if args.output_dir is not None:
                # BUG FIX: os.path.join instead of plain concatenation, which
                # produced "weightsname_best_dice" when --output-dir had no
                # trailing slash.
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce,
                }, os.path.join(args.output_dir, snapshot_name + "_best_dice"))
            bce_best = bce
            # Dump raw per-video predictions for offline analysis.
            with open("predictions_{}.json".format(args.fold), "w") as f:
                json.dump({"probs": probs, "targets": targets}, f)
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'bce_best': bce_best,
        }, os.path.join(args.output_dir, snapshot_name + "_last"))
        print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
    return bce_best
def validate(net, data_loader, prefix=""):
    """Evaluate `net` on `data_loader`, aggregating frame-level predictions per
    video (mean over frames), then scoring fake and real videos separately.

    :param net: model mapping an image batch to per-sample logits.
    :param data_loader: yields dicts with "image", "img_name" ("video/frame"
        paths) and "labels".
    :param prefix: string prepended to the printed loss names.
    :return: ((fake_loss + real_loss) / 2, per-video probs, per-video targets).
    """
    probs = defaultdict(list)
    targets = defaultdict(list)
    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda()
            img_names = sample["img_name"]
            labels = sample["labels"].cuda().float()
            out = net(imgs)
            labels = labels.cpu().numpy()
            preds = torch.sigmoid(out).cpu().numpy()
            # Group frame-level predictions by their source video.
            for i in range(out.shape[0]):
                video, img_id = img_names[i].split("/")
                probs[video].append(preds[i].tolist())
                targets[video].append(labels[i].tolist())
    data_x = []
    data_y = []
    for vid, score in probs.items():
        # Video-level score/label = mean over the video's frames.
        data_x.append(np.mean(np.array(score)))
        data_y.append(np.mean(targets[vid]))
    y = np.array(data_y)
    x = np.array(data_x)
    # Split at 0.1 so label-smoothed targets still classify correctly.
    # NOTE(review): a video whose mean label is exactly 0.1 falls in neither
    # bucket — presumably impossible with the smoothing in use; verify.
    fake_idx = y > 0.1
    real_idx = y < 0.1
    fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
    real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
    print("{}fake_loss".format(prefix), fake_loss)
    print("{}real_loss".format(prefix), real_loss)
    return (fake_loss + real_loss) / 2, probs, targets
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank, only_valid):
    """Run one training epoch.

    Computes the classifier loss separately on fake (label > 0.5) and real
    samples and averages the two terms, so in-batch class imbalance does not
    dominate the gradient. With "ohem_samples" set in `conf`, only the top-k
    hardest samples of each class contribute (requires reduction="none").

    :param only_valid: if True, train only on samples flagged "valid".
    """
    # NOTE: meters renamed from `losses`/... — the original local `losses`
    # shadowed the file-level `training.losses` module import.
    loss_meter = AverageMeter()
    fake_loss_meter = AverageMeter()
    real_loss_meter = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    # Hoisted out of the loop: invariant across iterations.
    ohem = conf.get("ohem_samples", None)
    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5
        # BUG FIX: OHEM top-k moved inside the non-empty branches — previously
        # a batch with no fake (or no real) samples crashed on
        # topk(0, ...) because the loss was still the int sentinel 0.
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
            if ohem:
                fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
            if ohem:
                real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        loss_meter.update(loss.item(), imgs.size(0))
        fake_loss_meter.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
        real_loss_meter.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
        optimizer.zero_grad()
        pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": loss_meter.avg,
                          "fake_loss": fake_loss_meter.avg, "real_loss": real_loss_meter.avg})

        if conf['fp16']:
            # apex amp scales the loss to avoid fp16 gradient underflow.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(loss_meter.avg), global_step=current_epoch)
if __name__ == '__main__':
    main()