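"""Train a binary ScoringModel on clips of frames embedded by a frozen i2v backbone.

Each clip of `frames_per_clip` frames is embedded frame by frame with i2v, the
embeddings are flattened into a single vector (optionally as frame-to-frame
differences), and the ScoringModel is trained with BCE-with-logits against
binary labels derived from the dataset's folder names.
"""
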
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from batch_image_transforms import batch_transform_train, batch_transform_val
from data import DatasetFolder, get_target_transform
from i2v import i2v_transform, EMBED_DIM, load_weight
from metric import binary_accuracy, precision, recall
from model import ScoringModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)
TRAINING_P = 0.8
CHECKPOINTS_DIRECTORY = os.path.join(CWD, '..', 'checkpoints')
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')

# The i2v backbone is used only as a frozen feature extractor, so keep it in eval mode.
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v.eval()
i2v_transform.to(device)


def save_checkpoint(dict_to_save: dict, path: str):
    directory = os.path.dirname(path)
    os.makedirs(directory, exist_ok=True)
    torch.save(dict_to_save, path)


def write_tb_logs(stored_value, epoch, folder: str = "Loss/train"):
    if writer is not None:
        writer.add_scalar(folder, stored_value, epoch)
        writer.flush()


def train(epoch):
    scoringModel.train()
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for index, (x, y_true) in enumerate(train_loader):
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        # The i2v backbone is frozen, so embed the frames without tracking gradients.
        with torch.no_grad():
            x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            # Replace every embedding after the first with its difference from the previous frame.
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)  # (batch_size, frames_per_clip * EMBED_DIM)
        y_pred = scoringModel(x).view(-1)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        loss = loss_fn(y_pred, y_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if index % 10 == 0:
            print('Loss[Epoch:{}, Iteration:{}] = {:.8f}'.format(epoch, index, loss.item()))
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    # Precision/recall are per-batch averages, so normalize by the number of batches.
    avg_precision = total_precision / len(train_loader)
    avg_recall = total_recall / len(train_loader)
    print('Finished training for epoch {} with average loss/acc: {}/{}'.format(epoch, avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/train")
    write_tb_logs(avg_acc, epoch, "Accuracy/train")
    write_tb_logs(avg_precision, epoch, "Precision/train")
    write_tb_logs(avg_recall, epoch, "Recall/train")
    return avg_loss


@torch.no_grad()  # no gradients are needed during validation
def validate(epoch):
    scoringModel.eval()
    if epoch is not None:
        print('Calculating validation loss for epoch', epoch)
    else:
        print('Calculating validation loss')
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for x, y_true in val_loader:
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)
        y_pred = scoringModel(x).view(-1)
        loss = loss_fn(y_pred, y_true)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    avg_precision = total_precision / len(val_loader)
    avg_recall = total_recall / len(val_loader)
    print('Finished calculating validation loss/acc: {}/{}'.format(avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/val")
    write_tb_logs(avg_acc, epoch, "Accuracy/val")
    write_tb_logs(avg_precision, epoch, "Precision/val")
    write_tb_logs(avg_recall, epoch, "Recall/val")
    return avg_loss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    ### Training related arguments ###
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--n_epochs', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--validate', action='store_true')
    parser.add_argument('--pos_weight', action='store_true')
    parser.add_argument('--checkpoint', type=str, default=None)
    ### Data related arguments ###
    parser.add_argument('--frames_per_clip', type=int, default=3)
    parser.add_argument('--dataset_root_dir', type=str, default=None)
    parser.add_argument('--positive_labels', type=str, required=True)
    parser.add_argument('--excluded_folders', type=str, default='')
    parser.add_argument('--balanced_dataset', action='store_true')
    ### Logging related arguments ###
    parser.add_argument('--exp_name', type=str, default="")
    parser.add_argument('--use_tb', action='store_true')
    ### Scheduler related arguments ###
    parser.add_argument('--scheduler_gamma', type=float, default=0.1)
    parser.add_argument('--scheduler_step_size', type=int, default=1)
    parser.add_argument('--scheduler_step_till_epoch', type=int, default=10)
    ### Model related arguments ###
    parser.add_argument('--model_num_hidden_layers', type=int, default=1)
    parser.add_argument('--model_hidden_dim', type=int, default=10)
    ### Input related arguments ###
    # Convert the input `x = [frame_embedding_1, frame_embedding_2, ..., frame_embedding_n]` to
    # `[frame_embedding_1, frame_embedding_2 - frame_embedding_1, ..., frame_embedding_n - frame_embedding_{n-1}]`.
    parser.add_argument('--frame_diff', action='store_true')
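    # e.g. with frames_per_clip=3: [e1, e2, e3] -> [e1, e2 - e1, e3 - e2]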
    args = parser.parse_args()
    print('Arguments:', args)
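
    # Example invocation (script name, paths, and labels below are illustrative):
    #   python train.py --dataset_root_dir ../data --positive_labels 'goal header' \
    #       --batch_size 8 --n_epochs 10 --use_tb --exp_name baseline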

    # To inspect logs: python3 -m tensorboard.main --logdir=logs
    writer = SummaryWriter(log_dir=f'../logs/{args.exp_name}', comment=f'_{args.exp_name}') if args.use_tb else None

    # Positive labels and excluded folders are passed as space-separated strings;
    # str.split() also maps the empty default to an empty list.
    positive_labels = args.positive_labels.split()
    excluded_folders = args.excluded_folders.split()

    ds = DatasetFolder(root=args.dataset_root_dir, frames_per_clip=args.frames_per_clip,
                       transform_train=batch_transform_train, transform_val=batch_transform_val,
                       excluded_folders=excluded_folders, balanced_dataset=args.balanced_dataset)
    ds.target_transform = get_target_transform(ds.class_to_idx, *positive_labels)
    print(f'== DatasetFolder classes: {ds.classes}')
    print(f'== Class_to_idx: {ds.class_to_idx}')
    print(f'== Number of samples in each class: {ds.class_n_samples}')
    if args.pos_weight:
        print(f'== Positive weight: {ds.pos_weight}\n')
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([ds.pos_weight]).to(device)) if args.pos_weight else nn.BCEWithLogitsLoss()
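    # Note: with pos_weight ~ n_negative / n_positive (assuming ds.pos_weight is
    # defined that way), the positive term of BCEWithLogitsLoss is up-weighted so
    # both classes contribute comparably on an imbalanced dataset.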

    ### START: Create training and validation sets ###
    indices = list(range(len(ds)))
    np.random.shuffle(indices)
    split_train = int(np.floor(TRAINING_P * len(indices)))
    train_indices, val_indices = indices[:split_train], indices[split_train:]
    ds.val_indices = val_indices
    ds.set_target_to_indices_dict(ds.target_transform, ds.samples)
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    kwargs = {'pin_memory': True, 'num_workers': 4} if device == 'cuda' else {}
    train_loader = DataLoader(ds, sampler=train_sampler, batch_size=args.batch_size, **kwargs)
    val_loader = DataLoader(ds, sampler=val_sampler, batch_size=args.batch_size, **kwargs)
    ### END: Create training and validation sets ###

    scoringModel = ScoringModel(frames_per_clip=args.frames_per_clip, input_dim=EMBED_DIM,
                                hidden_dim=args.model_hidden_dim, num_hidden_layers=args.model_num_hidden_layers)
    if args.checkpoint:
        print('Loading checkpoint', args.checkpoint)
        assert os.path.exists(args.checkpoint), 'File not found: {}'.format(args.checkpoint)
        ckpt_dict = torch.load(args.checkpoint, map_location='cpu')
        scoringModel = ckpt_dict['model']
    scoringModel.to(device)
    # Build the optimizer and scheduler only after the (possibly restored) model is
    # on the target device, so they track the right parameters and any loaded
    # optimizer state is cast to the matching device.
    optimizer = Adam(scoringModel.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
    if args.checkpoint:
        optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt_dict['scheduler_state_dict'])

    if args.validate:
        validate(None)
    else:
        train_losses, val_losses = [], []
        for epoch in range(args.n_epochs):
            train_losses.append(train(epoch))
            val_losses.append(validate(epoch))
            if epoch < args.scheduler_step_till_epoch:
                scheduler.step()
            ckpt_file = os.path.join(CHECKPOINTS_DIRECTORY, f'ckpt_epoch_{epoch}_loss_{train_losses[-1]:.6f}.ckpt')
            ckpt = {'model': scoringModel,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'args': vars(args),
                    'epoch': epoch,
                    'train_losses': train_losses,
                    'val_losses': val_losses}
            save_checkpoint(ckpt, ckpt_file)
        print('=== Training over ===')
        print('Train losses:', train_losses)
        print('Val losses:', val_losses)