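"""Train a binary ScoringModel on top of frozen i2v frame embeddings.

Each clip is a stack of frames; every frame is embedded with the pretrained i2v
model, the embeddings are concatenated (optionally as frame-to-frame
differences), and a ScoringModel head is trained with BCEWithLogitsLoss to
separate positive from negative clips.

Illustrative invocation (script name, paths, and label names below are
placeholders, not taken from the original project):

    python train.py --dataset_root_dir /path/to/clips \
        --positive_labels "positive_class" --frames_per_clip 3 \
        --batch_size 8 --n_epochs 10 --use_tb --exp_name baseline
"""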
import torch
import torch.nn as nn
import numpy as np
import argparse
import os
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from batch_image_transforms import batch_transform_train, batch_transform_val
from i2v import i2v_transform, EMBED_DIM, load_weight
from data import DatasetFolder, get_target_transform
from model import ScoringModel
from metric import binary_accuracy, precision, recall
from torch.utils.tensorboard import SummaryWriter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CWD = os.path.dirname(__file__)
TRAINING_P = 0.8
CHECKPOINTS_DIRECTORY = os.path.join(CWD, '..', 'checkpoints')
WEIGHT_i2v = os.path.join(CWD, '..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')

# The i2v feature extractor is frozen, so keep it in eval mode: any dropout or
# batch-norm layers then behave deterministically during feature extraction.
i2v = load_weight(WEIGHT_i2v)
i2v.to(device)
i2v.eval()
i2v_transform.to(device)

def save_checkpoint(dict_to_save: dict, path: str):
    directory = os.path.dirname(path)
    if not os.path.exists(directory):
        os.makedirs(directory)
    torch.save(dict_to_save, path)

def write_tb_logs(stored_value, epoch, folder: str = "Loss/train"):
    if writer is not None:
        writer.add_scalar(folder, stored_value, epoch)
        writer.flush()

def train(epoch):
    scoringModel.train()
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for index, (x, y_true) in enumerate(train_loader):
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        with torch.no_grad():
            x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)
        y_pred = scoringModel(x).view(-1)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        loss = loss_fn(y_pred, y_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if index % 10 == 0:
            print('Loss[Epoch:{}, Iteration:{}] = {:.8f}'.format(epoch, index, loss.item()))
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    avg_precision = total_precision / len(train_loader)
    avg_recall = total_recall / len(train_loader)
    print('Finished training for epoch {} with average loss/acc: {}/{}'.format(epoch, avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/train")
    write_tb_logs(avg_acc, epoch, "Accuracy/train")
    write_tb_logs(avg_precision, epoch, "Precision/train")
    write_tb_logs(avg_recall, epoch, "Recall/train")
    return avg_loss

@torch.no_grad()  # no gradients are needed for validation
def validate(epoch):
    scoringModel.eval()
    if epoch is not None:
        print('Calculating validation loss for epoch', epoch)
    else:
        print('Calculating validation loss')
    total_n, total_loss, total_acc = 0, 0.0, 0.0
    total_precision, total_recall = 0.0, 0.0
    for x, y_true in val_loader:
        x = x.to(device)  # (batch_size, frames_per_clip, 3, 224, 224)
        y_true = y_true.to(device).float()
        batch_size = x.size(0)
        x = x.view(-1, 3, 224, 224)  # (batch_size * frames_per_clip, 3, 224, 224)
        x, _ = i2v(i2v_transform(x))  # (batch_size * frames_per_clip, 512)
        x = x.view(batch_size, -1, EMBED_DIM)  # (batch_size, frames_per_clip, 512)
        if args.frame_diff:
            x[:, 1:, :] = x[:, 1:, :] - x[:, :-1, :]
        x = x.view(batch_size, -1)
        y_pred = scoringModel(x).view(-1)
        loss = loss_fn(y_pred, y_true)
        acc = binary_accuracy(y_pred, y_true)
        pr = precision(y_pred, y_true)
        rc = recall(y_pred, y_true)
        total_acc += acc * batch_size
        total_precision += pr
        total_recall += rc
        total_loss += loss.item() * batch_size
        total_n += batch_size
    avg_loss = total_loss / total_n
    avg_acc = total_acc / total_n
    avg_precision = total_precision / len(val_loader)
    avg_recall = total_recall / len(val_loader)
    print('Finished calculating validation loss/acc: {}/{}'.format(avg_loss, avg_acc))
    write_tb_logs(avg_loss, epoch, "Loss/val")
    write_tb_logs(avg_acc, epoch, "Accuracy/val")
    write_tb_logs(avg_precision, epoch, "Precision/val")
    write_tb_logs(avg_recall, epoch, "Recall/val")
    return avg_loss

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    ### Training related arguments ###
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--n_epochs', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--validate', action='store_true')
    parser.add_argument('--pos_weight', action='store_true')
    parser.add_argument('--checkpoint', type=str, default=None)
    ### Data related arguments ###
    parser.add_argument('--frames_per_clip', type=int, default=3)
    parser.add_argument('--dataset_root_dir', type=str, default=None)
    parser.add_argument('--positive_labels', type=str, default=None, required=True)
    parser.add_argument('--excluded_folders', type=str, default='')
    parser.add_argument('--balanced_dataset', action='store_true')
    ### Logging related arguments ###
    parser.add_argument('--exp_name', type=str, default="")
    parser.add_argument('--use_tb', action='store_true')
    ### Scheduler related arguments ###
    parser.add_argument('--scheduler_gamma', type=float, default=0.1)
    parser.add_argument('--scheduler_step_size', type=int, default=1)
    parser.add_argument('--scheduler_step_till_epoch', type=int, default=10)
    ### Model related arguments ###
    parser.add_argument('--model_num_hidden_layers', type=int, default=1)
    parser.add_argument('--model_hidden_dim', type=int, default=10)
    ### Input related arguments ###
    # Convert the input `x = [frame_embedding_1, frame_embedding_2, ..., frame_embedding_n]` to
    # `[frame_embedding_1, frame_embedding_2 - frame_embedding_1, ..., frame_embedding_n - frame_embedding_{n-1}]`.
    parser.add_argument('--frame_diff', action='store_true')
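    # Illustrative example (not in the original script): with frames_per_clip=3 and
    # per-frame embeddings e1, e2, e3, enabling --frame_diff feeds the model
    # [e1, e2 - e1, e3 - e2] instead of [e1, e2, e3], so every slot after the first
    # encodes frame-to-frame change rather than absolute frame content.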
    args = parser.parse_args()
    print('Arguments:', args)

    # Optional TensorBoard logging.
    # To view the logs: python3 -m tensorboard.main --logdir=logs
    writer = SummaryWriter(log_dir=f'../logs/{args.exp_name}', comment=f'_{args.exp_name}') if args.use_tb else None

    # Positive labels and excluded folders are passed as space-separated strings of labels.
    positive_labels = args.positive_labels.split(' ')
    excluded_folders = args.excluded_folders.split(' ')
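    # For example (hypothetical label names): --positive_labels "goal foul" yields
    # positive_labels == ['goal', 'foul']; leaving --excluded_folders empty yields [''].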
    ds = DatasetFolder(root=args.dataset_root_dir, frames_per_clip=args.frames_per_clip,
                       transform_train=batch_transform_train, transform_val=batch_transform_val,
                       excluded_folders=excluded_folders, balanced_dataset=args.balanced_dataset)
    ds.target_transform = get_target_transform(ds.class_to_idx, *positive_labels)
    print(f'== DatasetFolder classes: {ds.classes}')
    print(f'== Class_to_idx: {ds.class_to_idx}')
    print(f'== Number of samples in each class: {ds.class_n_samples}')
    if args.pos_weight:
        print(f'== Positive weight: {ds.pos_weight} \n')
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([ds.pos_weight]).to(device)) if args.pos_weight else nn.BCEWithLogitsLoss()
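    # Note: BCEWithLogitsLoss(pos_weight=w) scales the loss of positive samples by w;
    # ds.pos_weight is presumably the negative/positive sample ratio of the dataset
    # (an assumption about the DatasetFolder implementation, not verified here).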
    ### START: Create training and validation sets ###
    indices = list(range(len(ds)))
    np.random.shuffle(indices)
    split_train = int(np.floor(TRAINING_P * len(indices)))
    train_indices, val_indices = indices[:split_train], indices[split_train:]
    ds.val_indices = val_indices
    ds.set_target_to_indices_dict(ds.target_transform, ds.samples)
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    kwargs = {'pin_memory': True, 'num_workers': 4} if device == 'cuda' else {}
    train_loader = DataLoader(ds, sampler=train_sampler, batch_size=args.batch_size, **kwargs)
    val_loader = DataLoader(ds, sampler=val_sampler, batch_size=args.batch_size, **kwargs)
    ### END: Create training and validation sets ###
    scoringModel = ScoringModel(frames_per_clip=args.frames_per_clip, input_dim=EMBED_DIM,
                                hidden_dim=args.model_hidden_dim, num_hidden_layers=args.model_num_hidden_layers)
    if args.checkpoint:
        print('Loading checkpoint', args.checkpoint)
        assert os.path.exists(args.checkpoint), 'File not found: {}'.format(args.checkpoint)
        ckpt_dict = torch.load(args.checkpoint, map_location='cpu')
        scoringModel = ckpt_dict['model']
    scoringModel.to(device)
    # Build the optimizer/scheduler around the (possibly checkpoint-loaded) model's
    # parameters before restoring their state, so that the optimizer actually updates
    # the model that will be trained.
    optimizer = Adam(scoringModel.parameters(), lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
    if args.checkpoint:
        optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt_dict['scheduler_state_dict'])
    if args.validate:
        scoringModel.eval()
        validate(None)
    else:
        train_losses, val_losses = [], []
        for epoch in range(args.n_epochs):
            scoringModel.train()
            loss = train(epoch)
            train_losses.append(loss)
            loss = validate(epoch)
            val_losses.append(loss)
            if epoch < args.scheduler_step_till_epoch:
                scheduler.step()
            ckpt_file = os.path.join(CHECKPOINTS_DIRECTORY, f'ckpt_epoch_{epoch}_loss_{train_losses[-1]}.ckpt')
            ckpt = {'model': scoringModel,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'args': vars(args),
                    'epoch': epoch,
                    'train_losses': train_losses,
                    'val_losses': val_losses}
            save_checkpoint(ckpt, ckpt_file)
        print('=== Training over ===')
        print('Train losses:', train_losses)
        print('Val losses:', val_losses)
    if writer is not None:
        writer.close()