from typing import Any, Dict
from schema import Schema, Optional
from data import Scenario, MergedDataset, build_dataloader
from methods.base.alg import BaseAlg
from ..model import ElasticDNN_OfflineFMModel
from ...model.base import ElasticDNNUtil
import torch
import torch.optim
import tqdm
from torch import nn
from torchvision.transforms import Compose
from utils.dl.common.env import create_tbwriter
import os
import random
import numpy as np
from copy import deepcopy
from utils.dl.common.model import get_module
from utils.common.log import logger


class ElasticDNN_FMLoRAAlg(BaseAlg):
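    """Offline LoRA fine-tuning of the foundation model (FM).

    Injects low-rank A/B adapters into the FM, optionally restores them from a
    previously trained LoRA checkpoint, and then trains only the adapter and
    task-head parameters on the merged offline datasets of the scenario.
    """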

    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel
        })

    def get_required_hyp_schema(self) -> Schema:
        return Schema({
            'launch_tbboard': bool,

            # either an example-input shape (a random tensor is generated from it)
            # or a ready-made example input; see run() below
            'samples_size': object,
            # rank setting passed to the LoRA util (presumably the rank of the
            # injected A/B matrices)
            'ab_r': int,

            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,

            Optional('fm_lora_ckpt_path'): str,
            Optional('transform'): Compose,
        })

    def run(self, scenario: Scenario, hyps: Dict, collate_fn=None) -> Dict[str, Any]:
        super().run(scenario, hyps)

        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)
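
        # inject LoRA A/B adapters into the foundation model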
        lora_util = self.models['fm'].get_lora_util()
        device = self.models['fm'].device

        sample = hyps['samples_size']
        if isinstance(sample, (tuple, list)) and isinstance(sample[0], int):
            sample = torch.rand(hyps['samples_size']).to(device)
        lora_util.add_lora_ab_to_fm(self.models['fm'].models_dict['main'], hyps['ab_r'], sample)
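
        # optionally restore LoRA weights from a previous checkpoint; only the
        # injected 'qkv.abs' (LoRA A/B) entries are copied into the current model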
        if hyps.get('fm_lora_ckpt_path') not in (None, ''):
            _ckpt = torch.load(hyps['fm_lora_ckpt_path'])['main']

            new_state_dict = deepcopy(self.models['fm'].models_dict['main'].state_dict())

            # the checkpoint stores the module under 'main' (it exposes
            # named_parameters()), so the LoRA entries can be iterated by name
            for n, p in _ckpt.named_parameters():
                if 'qkv.abs' not in n:
                    continue

                new_state_dict[n] = p
                logger.info(f'use {n} from ckpt')

            self.models['fm'].models_dict['main'].load_state_dict(new_state_dict)
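
        # build the merged offline train/val datasets and dataloaders; the train
        # loader is wrapped in iter() because the loop below draws batches with next()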
        if 'transform' in hyps:
            offline_datasets = scenario.get_offline_datasets(transform=hyps['transform'])
        else:
            offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])

        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None, collate_fn=collate_fn))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False, collate_fn=collate_fn)
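
        # collect the trainable parameters: train_only_lora() restricts FM training
        # to the LoRA modules, and the task-head parameters are optimized alongside them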
        lora_params = lora_util.train_only_lora(self.models['fm'].models_dict['main'])
        head_params = self.models['fm'].get_task_head_params()

        num_lora_params = sum([np.prod(p.size()) for p in lora_params])
        total_params = sum([np.prod(p.size()) for p in self.models['fm'].models_dict['main'].parameters()])
        logger.info(f'num lora params: {num_lora_params}, total params: {total_params}, ratio: {num_lora_params / total_params}')
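
        # the optimizer and LR scheduler classes are looked up by name, so
        # hyps['optimizer'] / hyps['scheduler'] must match class names in
        # torch.optim and torch.optim.lr_scheduler (e.g. 'AdamW', 'LambdaLR')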
        optimizer = torch.optim.__dict__[hyps['optimizer']](lora_params + head_params, **hyps['optimizer_args'])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
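
        # TensorBoard writer and progress bar; training runs for a fixed number
        # of iterations (hyps['num_iters']) rather than epochs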
        fbs_tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)

        best_val_acc = 0
        val_acc = 0
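
        # fine-tuning loop: only the LoRA and task-head parameters receive gradient
        # updates; every 'val_freq' iterations the model is evaluated and checkpointed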
        for iter_index in pbar:
            self.models['fm'].to_train_mode()

            x, y = next(train_loader)

            # the batch may be a dict of tensors (e.g. tokenized text inputs);
            # move each tensor field to the device
            if isinstance(x, dict):
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(device)
                y = y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            task_loss = self.models['fm'].forward_to_get_task_loss(x, y)
            optimizer.zero_grad()
            task_loss.backward()
            optimizer.step()
            scheduler.step()

            if (iter_index + 1) % hyps['val_freq'] == 0:
                self.models['fm'].to_eval_mode()
                val_acc = self.models['fm'].get_accuracy(val_loader)

                # always keep the latest checkpoint, and keep the best one separately
                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

            fbs_tb_writer.add_scalar('losses/task_loss', task_loss, iter_index)
            fbs_tb_writer.add_scalar('accs/val_acc', val_acc, iter_index)
            fbs_tb_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], iter_index)
            pbar.set_description(f'loss: {task_loss:.6f}, val_acc: {val_acc:.4f}')