import torch
import torch.nn as nn
import numpy as np
import argparse
import os
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from batch_image_transforms import batch_transform_train, batch_transform_val
from i2v import i2v_transform, EMBED_DIM, load_weight
from data import DatasetFolder, get_target_transform
from model import ScoringModel
from metric import binary_accuracy, precision, recall
from torch.utils.tensorboard import SummaryWriter

device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)
TRAINING_P = 0.8
CHECKPOINTS_DIRECTORY = os.path.join(CWD, '..', 'checkpoints')
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')

# Frozen frame-embedding backbone; only ScoringModel is trained below.
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v_transform.to(device)


def save_checkpoint(dict_to_save: dict, path: str):
    directory = os.path.dirname(path)
    if not os.path.exists(directory):
        os.makedirs(directory)
    torch.save(dict_to_save, path)


def write_tb_logs(stored_value, epoch, folder: str = "Loss/train"):
    if writer is not None:
        writer.add_scalar(folder, stored_value, epoch)
        writer.flush()


def train(epoch):
    scoringModel.train()
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for index, (x, y_true) in enumerate(train_loader):
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        with torch.no_grad():
            x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)
        y_pred = scoringModel(x).view(-1)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        loss = loss_fn(y_pred, y_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if index % 10 == 0:
            print('Loss[Epoch:{}, Iteration:{}] = {:.8f}'.format(epoch, index, loss.item()))
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    # Note: precision/recall are averaged per batch, while loss/accuracy are
    # averaged per sample.
    avg_precision = total_precision / len(train_loader)
    avg_recall = total_recall / len(train_loader)
    print('Finished training for epoch {} with average loss/acc: {}/{}'.format(epoch, avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/train")
    write_tb_logs(avg_acc, epoch, "Accuracy/train")
    write_tb_logs(avg_precision, epoch, "Precision/train")
    write_tb_logs(avg_recall, epoch, "Recall/train")
    return avg_loss
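
# --frame_diff (used in train() above and validate() below) rewrites each clip's
# embeddings in place so that every frame after the first is replaced by its
# delta from the previous frame:
#   [e_1, e_2, ..., e_n] -> [e_1, e_2 - e_1, ..., e_n - e_{n-1}]
# A minimal sketch with toy values (the shape here is illustrative, not taken
# from the real data):
#   x = torch.arange(6.).view(1, 3, 2)        # (batch=1, frames=3, dim=2)
#   x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
#   # x is now tensor([[[0., 1.], [2., 2.], [2., 2.]]])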

@torch.no_grad()  # validation never backpropagates; skip building the graph
def validate(epoch):
    scoringModel.eval()
    if epoch is not None:
        print('Calculating validation loss for epoch', epoch)
    else:
        print('Calculating validation loss')
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for x, y_true in val_loader:
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)
        y_pred = scoringModel(x).view(-1)
        loss = loss_fn(y_pred, y_true)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    avg_precision = total_precision / len(val_loader)
    avg_recall = total_recall / len(val_loader)
    print('Finished calculating validation loss/acc: {}/{}'.format(avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/val")
    write_tb_logs(avg_acc, epoch, "Accuracy/val")
    write_tb_logs(avg_precision, epoch, "Precision/val")
    write_tb_logs(avg_recall, epoch, "Recall/val")
    return avg_loss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    ### Training related arguments ###
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--n_epochs', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--validate', action='store_true')
    parser.add_argument('--pos_weight', action='store_true')
    parser.add_argument('--checkpoint', type=str, default=None)
    ### Data related arguments ###
    parser.add_argument('--frames_per_clip', type=int, default=3)
    parser.add_argument('--dataset_root_dir', type=str, default=None)
    parser.add_argument('--positive_labels', type=str, default=None, required=True)
    parser.add_argument('--excluded_folders', type=str, default='')
    parser.add_argument('--balanced_dataset', action='store_true')
    ### Logging related arguments ###
    parser.add_argument('--exp_name', type=str, default="")
    parser.add_argument('--use_tb', action='store_true')
    ### Scheduler related arguments ###
    parser.add_argument('--scheduler_gamma', type=float, default=0.1)
    parser.add_argument('--scheduler_step_size', type=int, default=1)
    parser.add_argument('--scheduler_step_till_epoch', type=int, default=10)
    ### Model related arguments ###
    parser.add_argument('--model_num_hidden_layers', type=int, default=1)
    parser.add_argument('--model_hidden_dim', type=int, default=10)
    ### Input related arguments ###
    # Convert the input `x = [frame_embedding_1, frame_embedding_2, ..., frame_embedding_n]` to
    # `[frame_embedding_1, frame_embedding_2 - frame_embedding_1, ..., frame_embedding_n - frame_embedding_{n-1}]`.
    parser.add_argument('--frame_diff', action='store_true')
    args = parser.parse_args()
    print('Arguments:', args)

    # TensorBoard log writer (None when --use_tb is not set).
    # To view the logs: python3 -m tensorboard.main --logdir=logs
    writer = SummaryWriter(log_dir=f'../logs/{args.exp_name}', comment=f'_{args.exp_name}') if args.use_tb else None

    # Positive labels and excluded folders are passed as space-separated strings.
    positive_labels = args.positive_labels.split(' ')
    excluded_folders = args.excluded_folders.split(' ')
    ds = DatasetFolder(root=args.dataset_root_dir,
                       frames_per_clip=args.frames_per_clip,
                       transform_train=batch_transform_train,
                       transform_val=batch_transform_val,
                       excluded_folders=excluded_folders,
                       balanced_dataset=args.balanced_dataset)
    ds.target_transform = get_target_transform(ds.class_to_idx, *positive_labels)
    print(f'== DatasetFolder classes: {ds.classes}')
    print(f'== Class_to_idx: {ds.class_to_idx}')
    print(f'== Number of samples in each class: {ds.class_n_samples}')
    if args.pos_weight:
        print(f'== Positive weight: {ds.pos_weight}\n')
    loss_fn = (nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([ds.pos_weight]).to(device))
               if args.pos_weight else nn.BCEWithLogitsLoss())
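
    # The train/val split below depends on np.random.shuffle, so it changes
    # from run to run. If a reproducible split is wanted (not done in this
    # script), the generator could be seeded first, e.g.:
    #   np.random.seed(0)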
    ### START: Create training and validation sets ###
    indices = list(range(len(ds)))
    np.random.shuffle(indices)
    split_train = int(np.floor(TRAINING_P * len(indices)))
    train_indices, val_indices = indices[:split_train], indices[split_train:]
    ds.val_indices = val_indices
    ds.set_target_to_indices_dict(ds.target_transform, ds.samples)
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    kwargs = {'pin_memory': True, 'num_workers': 4} if device == 'cuda' else {}
    train_loader = DataLoader(ds, sampler=train_sampler, batch_size=args.batch_size, **kwargs)
    val_loader = DataLoader(ds, sampler=val_sampler, batch_size=args.batch_size, **kwargs)
    ### END: Create training and validation sets ###

    scoringModel = ScoringModel(frames_per_clip=args.frames_per_clip,
                                input_dim=EMBED_DIM,
                                hidden_dim=args.model_hidden_dim,
                                num_hidden_layers=args.model_num_hidden_layers)
    if args.checkpoint:
        print('Loading checkpoint', args.checkpoint)
        assert os.path.exists(args.checkpoint), f'File not found: {args.checkpoint}'
        ckpt_dict = torch.load(args.checkpoint, map_location='cpu')
        scoringModel = ckpt_dict['model']
    scoringModel.to(device)

    # Build the optimizer/scheduler from the (possibly restored) model after it
    # has been moved to the target device, so they reference the parameters that
    # are actually trained; creating them before swapping in the checkpointed
    # model would leave them tracking stale tensors.
    optimizer = Adam(scoringModel.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
    if args.checkpoint:
        optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt_dict['scheduler_state_dict'])

    if args.validate:
        scoringModel.eval()
        validate(None)
    else:
        train_losses, val_losses = [], []
        for epoch in range(args.n_epochs):
            scoringModel.train()
            loss = train(epoch)
            train_losses.append(loss)
            loss = validate(epoch)
            val_losses.append(loss)
            if epoch < args.scheduler_step_till_epoch:
                scheduler.step()
            ckpt_file = os.path.join(CHECKPOINTS_DIRECTORY, f'ckpt_epoch_{epoch}_loss_{train_losses[-1]}.ckpt')
            ckpt = {'model': scoringModel,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'args': vars(args),
                    'epoch': epoch,
                    'train_losses': train_losses,
                    'val_losses': val_losses}
            save_checkpoint(ckpt, ckpt_file)
        print('=== Training over ===')
        print('Train losses:', train_losses)
        print('Val losses:', val_losses)
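
# Example invocation (illustrative only: the script name, dataset path, and
# label name below are placeholders, not taken from this repository):
#   python3 train.py \
#       --dataset_root_dir /path/to/dataset \
#       --positive_labels "goal" \
#       --batch_size 32 --n_epochs 10 --lr 1e-4 \
#       --pos_weight --use_tb --exp_name baseline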