LINC-BIT's picture
Upload 1912 files
b84549f verified
from typing import Any, Dict
from schema import Schema
from data import Scenario, MergedDataset
from methods.base.alg import BaseAlg
from data import build_dataloader
from ..model import ElasticDNN_OfflineFMModel
from ...model.base import ElasticDNNUtil
import torch.optim
import tqdm
from torch import nn
from torchvision.transforms import Compose
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import get_module
from utils.common.log import logger
class ElasticDNN_FMLoRAAlg(BaseAlg):
    """Fine-tune a foundation model ('fm') with LoRA adapters on offline data.

    ``run`` injects LoRA A/B modules into the FM's ``models_dict['main']`` via
    the model's LoRA util, optionally restores previously trained LoRA weights
    from a checkpoint, then trains only the LoRA parameters plus the task-head
    parameters on the merged offline train datasets of a ``Scenario``,
    periodically evaluating on the merged val datasets and saving checkpoints.
    """

    def get_required_models_schema(self) -> Schema:
        """Require a single ``'fm'`` model of type ElasticDNN_OfflineFMModel."""
        return Schema({
            'fm': ElasticDNN_OfflineFMModel
        })

    def get_required_hyp_schema(self) -> Schema:
        """Return the schema of hyperparameters accepted by ``run``.

        ``samples_size`` is either an input shape (a tuple/list whose first
        element is an int, from which a random tensor is built) or any other
        object, which is passed unchanged to ``add_lora_ab_to_fm``.
        ``optimizer`` / ``scheduler`` name classes in ``torch.optim`` /
        ``torch.optim.lr_scheduler``; the matching ``*_args`` dicts are
        forwarded to them as keyword arguments.
        """
        from schema import Optional
        return Schema({
            'launch_tbboard': bool,

            'samples_size': object,
            'ab_r': int,
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,

            Optional('fm_lora_ckpt_path'): str,
            Optional('transform'): Compose,
        })

    @staticmethod
    def _load_lora_ab_from_ckpt(main_model: nn.Module, ckpt_path: str):
        """Overwrite the LoRA A/B weights of ``main_model`` with those in a checkpoint.

        The checkpoint's ``'main'`` entry is expected to be a whole module (its
        ``named_parameters()`` is iterated). Only parameters whose name contains
        ``'qkv.abs'`` are taken from it; every other entry keeps its current value.
        """
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
        # checkpoints. Since this checkpoint holds a whole module, PyTorch >= 2.6
        # (weights_only=True by default) may need weights_only=False; confirm.
        # Loaded onto CPU: load_state_dict copies the values into main_model's own
        # parameters anyway, so no second device copy is needed and checkpoints
        # saved on an unavailable device still load.
        ckpt_model = torch.load(ckpt_path, map_location='cpu')['main']
        # state_dict() returns a fresh mapping; replacing entries in it does not
        # touch the model, so deep-copying all FM weights is unnecessary.
        new_state_dict = main_model.state_dict()
        for n, p in ckpt_model.named_parameters():
            if 'qkv.abs' not in n:
                continue
            new_state_dict[n] = p
            logger.info(f'use {n} from ckpt')
        main_model.load_state_dict(new_state_dict)

    @staticmethod
    def _move_batch_to_device(x, y, device):
        """Move a ``(x, y)`` batch to ``device`` and return it.

        If ``x`` is a dict, its tensor values are moved in place and non-tensor
        values are left untouched; otherwise ``x`` itself is moved.
        ``y`` is always moved.
        """
        if isinstance(x, dict):
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
        else:
            x = x.to(device)
        return x, y.to(device)

    def run(self, scenario: Scenario, hyps: Dict, collate_fn=None) -> Dict[str, Any]:
        """Add LoRA to the FM and train the LoRA + task-head parameters.

        Args:
            scenario: source of the offline datasets; the ``'train'`` and
                ``'val'`` splits of every dataset it returns are merged.
            hyps: hyperparameters matching ``get_required_hyp_schema``.
            collate_fn: optional collate function forwarded to both dataloaders.

        Side effects:
            After every validation writes ``models/fm_last.pt`` and, when the
            validation accuracy improves, ``models/fm_best.pt``; TensorBoard
            logs go to ``tb_log``. All paths are under ``self.res_save_dir``.
        """
        super().run(scenario, hyps)

        fm = self.models['fm']
        assert isinstance(fm, ElasticDNN_OfflineFMModel)  # for auto completion
        device = fm.device

        # 1. add LoRA
        lora_util = fm.get_lora_util()
        sample = hyps['samples_size']
        if isinstance(sample, (tuple, list)) and isinstance(sample[0], int):
            # an input shape was given: build a random dummy input of that shape
            sample = torch.rand(sample).to(device)
        lora_util.add_lora_ab_to_fm(fm.models_dict['main'], hyps['ab_r'], sample)
        main_model = fm.models_dict['main']

        lora_ckpt_path = hyps.get('fm_lora_ckpt_path')
        if lora_ckpt_path:  # skips missing key, None and ''
            self._load_lora_ab_from_ckpt(main_model, lora_ckpt_path)

        # 2. train LoRA + task head with the FM's task loss
        if 'transform' in hyps:
            offline_datasets = scenario.get_offline_datasets(transform=hyps['transform'])
        else:
            offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])

        # NOTE(review): the positional flags presumably mean (infinite=True, shuffle=None)
        # for train and (infinite=False, shuffle=False) for val -- confirm against
        # build_dataloader. The loop below calls next() num_iters times, so the train
        # loader is expected never to be exhausted.
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None, collate_fn=collate_fn))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False, collate_fn=collate_fn)

        lora_params = lora_util.train_only_lora(main_model)
        head_params = fm.get_task_head_params()

        num_lora_params = sum(p.numel() for p in lora_params)
        total_params = sum(p.numel() for p in main_model.parameters())
        logger.info(f'num lora params: {num_lora_params}, total params: {total_params}, ratio: {num_lora_params / total_params}')

        optimizer = torch.optim.__dict__[hyps['optimizer']](lora_params + head_params, **hyps['optimizer_args'])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        fbs_tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)

        # -inf (not 0) so the first validation always writes models/fm_best.pt,
        # even if the accuracy is 0 at that point.
        best_val_acc = float('-inf')
        val_acc = 0

        try:
            for iter_index in pbar:
                fm.to_train_mode()

                x, y = next(train_loader)
                x, y = self._move_batch_to_device(x, y, device)

                task_loss = fm.forward_to_get_task_loss(x, y)
                optimizer.zero_grad()
                task_loss.backward()
                optimizer.step()
                scheduler.step()  # stepped per iteration, not per epoch

                if (iter_index + 1) % hyps['val_freq'] == 0:
                    fm.to_eval_mode()
                    val_acc = fm.get_accuracy(val_loader)
                    fm.save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        fm.save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

                # val_acc is the most recent validation result (0 before the first one)
                loss_value = task_loss.item()
                fbs_tb_writer.add_scalar('losses/task_loss', loss_value, iter_index)
                fbs_tb_writer.add_scalar('accs/val_acc', val_acc, iter_index)
                fbs_tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], iter_index)
                pbar.set_description(f'loss: {loss_value:.6f}, val_acc: {val_acc:.4f}')
        finally:
            # flush pending TensorBoard events and release the event file
            fbs_tb_writer.close()