# from typing import Any, Dict | |
# from schema import Schema, Or | |
# import schema | |
# from data import Scenario, MergedDataset | |
# from methods.base.alg import BaseAlg | |
# from data import build_dataloader | |
# from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel | |
# from ...model.base import ElasticDNNUtil | |
# import torch.optim | |
# import tqdm | |
# import torch.nn.functional as F | |
# from torch import nn | |
# from utils.dl.common.env import create_tbwriter | |
# import os | |
# import random | |
# import numpy as np | |
# from copy import deepcopy | |
# from utils.dl.common.model import LayerActivation2, get_module | |
# from utils.common.log import logger | |
# class ElasticDNN_MDPretrainingWoFBSAlg(BaseAlg): | |
# """ | |
# TODO: fine-tuned FM -> init MD -> trained MD -> construct indexes (only between similar weights) and fine-tune | |
# """ | |
# def get_required_models_schema(self) -> Schema: | |
# return Schema({ | |
# 'fm': ElasticDNN_OfflineFMModel, | |
# 'md': ElasticDNN_OfflineMDModel | |
# }) | |
# def get_required_hyp_schema(self) -> Schema: | |
# return Schema({ | |
# 'launch_tbboard': bool, | |
# 'samples_size': object, | |
# 'generate_md_width_ratio': int, | |
# 'train_batch_size': int, | |
# 'val_batch_size': int, | |
# 'num_workers': int, | |
# 'optimizer': str, | |
# 'optimizer_args': dict, | |
# 'scheduler': str, | |
# 'scheduler_args': dict, | |
# 'num_iters': int, | |
# 'val_freq': int, | |
# 'distill_loss_weight': float | |
# }) | |
# def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]: | |
# super().run(scenario, hyps) | |
# assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion | |
# assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion | |
# # 1. add FBS | |
# device = self.models['md'].device | |
# if self.models['md'].models_dict['main'] == -1: | |
# logger.info(f'init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...') | |
# before_fm_model = deepcopy(self.models['fm'].models_dict['main']) | |
# lora_util = self.models['fm'].get_lora_util() | |
# sample = hyps['samples_size'] | |
# if isinstance(sample, (tuple, list)) and isinstance(sample[0], int): | |
# sample = torch.rand(hyps['samples_size']).to(device) | |
# lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'], | |
# sample) | |
# self.models['fm'].models_dict['main'] = lora_absorbed_fm_model | |
# master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'], | |
# sample) | |
# self.models['fm'].models_dict['main'] = before_fm_model | |
# self.models['md'].models_dict['main'] = master_dnn | |
# self.models['md'].to(device) | |
# # 2. train (knowledge distillation, index relationship) | |
# offline_datasets = scenario.get_offline_datasets() | |
# train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()]) | |
# val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()]) | |
# train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'], | |
# True, None)) | |
# val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'], | |
# False, False) | |
# # val_acc = self.models['md'].get_accuracy(val_loader) | |
# # print(val_acc) | |
# # exit() | |
# # 2.1 train whole master DNN (knowledge distillation) | |
# self.models['md'].to_train_mode() | |
# for p in master_dnn.parameters(): | |
# p.requires_grad = True | |
# if hasattr(self.models['md'], 'get_trained_params'): | |
# trained_p = self.models['md'].get_trained_params() | |
# logger.info(f'use custom trained parameters!!') | |
# else: | |
# trained_p = self.models['md'].models_dict['main'].parameters() | |
# for p in trained_p: | |
# p.requires_grad = True | |
# optimizer = torch.optim.__dict__[hyps['optimizer']]([ | |
# {'params': trained_p, **hyps['optimizer_args']} | |
# ]) | |
# scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args']) | |
# tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard']) | |
# pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True) | |
# best_avg_val_acc = 0. | |
# md_output_hook = None | |
# for iter_index in pbar: | |
# self.models['md'].to_train_mode() | |
# self.models['fm'].to_eval_mode() | |
# # rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity'] | |
# # elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity) | |
# if md_output_hook is None: | |
# md_output_hook = self.models['md'].get_feature_hook() | |
# fm_output_hook = self.models['fm'].get_feature_hook() | |
# x, y = next(train_loader) | |
# if isinstance(x, dict): | |
# for k, v in x.items(): | |
# if isinstance(v, torch.Tensor): | |
# x[k] = v.to(device) | |
# y = y.to(device) | |
# else: | |
# x, y = x.to(device), y.to(device) | |
# with torch.no_grad(): | |
# fm_output = self.models['fm'].infer(x) | |
# task_loss = self.models['md'].forward_to_get_task_loss(x, y) | |
# if isinstance(md_output_hook, (tuple, list)): | |
# distill_loss = 0. | |
# for h1, h2 in zip(md_output_hook, fm_output_hook): | |
# md_output = h1.output | |
# fm_output = h2.output | |
# distill_loss += hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output) | |
# else: | |
# md_output = md_output_hook.output | |
# fm_output = fm_output_hook.output | |
# distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output) | |
# total_loss = task_loss + distill_loss | |
# optimizer.zero_grad() | |
# total_loss.backward() | |
# # for n, p in self.models['md'].models_dict['main'].named_parameters(): | |
# # if p.grad is not None: | |
# # print(n) | |
# # exit() | |
# optimizer.step() | |
# scheduler.step() | |
# if (iter_index + 1) % hyps['val_freq'] == 0: | |
# # elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main']) | |
# if isinstance(md_output_hook, (tuple, list)): | |
# [h.remove() for h in md_output_hook] | |
# [h.remove() for h in fm_output_hook] | |
# else: | |
# md_output_hook.remove() | |
# fm_output_hook.remove() | |
# md_output_hook = None | |
# fm_output_hook = None | |
# cur_md = self.models['md'].models_dict['main'] | |
# md_for_test = deepcopy(self.models['md'].models_dict['main']) | |
# val_acc = 0. | |
# self.models['md'].models_dict['main'] = md_for_test | |
# self.models['md'].to_eval_mode() | |
# val_acc = self.models['md'].get_accuracy(val_loader) | |
# self.models['md'].models_dict['main'] = cur_md | |
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt')) | |
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt')) | |
# if val_acc > best_avg_val_acc: | |
# best_avg_val_acc = val_acc | |
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt')) | |
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt')) | |
# tb_writer.add_scalars(f'losses', dict(task=task_loss, distill=distill_loss, total=total_loss), iter_index) | |
# pbar.set_description(f'loss: {total_loss:.6f}') | |
# if (iter_index + 1) >= hyps['val_freq']: | |
# tb_writer.add_scalar(f'accs/val_acc', val_acc, iter_index) | |
# pbar.set_description(f'loss: {total_loss:.6f}, val_acc: {val_acc:.4f}') | |
# NOTE: the code above was commented out on 07/16 17:49 because of a bug where the loss could not be
# gradient-descended (bug confirmed; cause still unknown). The rewritten version follows below.
from typing import Any, Dict | |
from schema import Schema, Or, Optional
import schema | |
from data import Scenario, MergedDataset | |
from methods.base.alg import BaseAlg | |
from data import build_dataloader | |
from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel | |
from ...model.base import ElasticDNNUtil | |
import torch.optim | |
import tqdm | |
import torch.nn.functional as F | |
from torch import nn | |
from utils.dl.common.env import create_tbwriter | |
import os | |
import random | |
import numpy as np | |
from copy import deepcopy | |
from utils.dl.common.model import LayerActivation2, get_module | |
from utils.common.log import logger | |
from torchvision.transforms import Compose | |
class ElasticDNN_MDPretrainingWoFBSAlg(BaseAlg): | |
""" | |
    Pretrain the master DNN (MD) without FBS: generate the MD by reducing the width of the
    LoRA-tuned foundation model (FM), then train it with the task loss plus a feature-level
    knowledge-distillation loss against the FM.

    Overall pipeline (TODO): fine-tuned FM -> init MD -> trained MD -> construct indexes
    (only between similar weights) and fine-tune
""" | |
def get_required_models_schema(self) -> Schema: | |
return Schema({ | |
'fm': ElasticDNN_OfflineFMModel, | |
'md': ElasticDNN_OfflineMDModel | |
}) | |
def get_required_hyp_schema(self) -> Schema: | |
return Schema({ | |
'launch_tbboard': bool, | |
            'samples_size': object,
'generate_md_width_ratio': int, | |
'train_batch_size': int, | |
'val_batch_size': int, | |
'num_workers': int, | |
'optimizer': str, | |
'optimizer_args': dict, | |
'scheduler': str, | |
'scheduler_args': dict, | |
'num_iters': int, | |
'val_freq': int, | |
'distill_loss_weight': float, | |
Optional('transform'): Compose, | |
}) | |
def run(self, scenario: Scenario, hyps: Dict, collate_fn=None) -> Dict[str, Any]: | |
super().run(scenario, hyps) | |
assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion | |
assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion | |
# 1. add FBS | |
device = self.models['md'].device | |
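        # models_dict['main'] is set to -1 as a sentinel meaning the master DNN (MD) has not been
        # generated yet; in that case, derive it from the LoRA-tuned foundation model (FM) below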
if self.models['md'].models_dict['main'] == -1: | |
logger.info(f'init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...') | |
before_fm_model = deepcopy(self.models['fm'].models_dict['main']) | |
lora_util = self.models['fm'].get_lora_util() | |
sample = hyps['samples_size'] | |
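            # hyps['samples_size'] is either an input shape (tuple/list of ints), from which a random
            # dummy input is built, or an already-prepared example input (e.g. a dict of tensors)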
if isinstance(sample, (tuple, list)) and isinstance(sample[0], int): | |
sample = torch.rand(hyps['samples_size']).to(device) | |
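            # absorb the LoRA branches into the base weights so that width reduction operates on a
            # plain network structure; the original FM weights are restored right after generation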
lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'], | |
sample) | |
self.models['fm'].models_dict['main'] = lora_absorbed_fm_model | |
master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'], | |
sample) | |
self.models['fm'].models_dict['main'] = before_fm_model | |
self.models['md'].models_dict['main'] = master_dnn | |
self.models['md'].to(device) | |
# 2. train (knowledge distillation, index relationship) | |
if 'transform' in hyps.keys(): | |
offline_datasets = scenario.get_offline_datasets(transform=hyps['transform']) | |
else: | |
offline_datasets = scenario.get_offline_datasets() | |
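        # each offline dataset provides its own train/val split; merge them so that a single
        # loader iterates over all domains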
train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()]) | |
val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()]) | |
train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'], | |
True, None, collate_fn=collate_fn)) | |
val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'], | |
False, False, collate_fn=collate_fn) | |
# logger.info(f'FM acc: {self.models["fm"].get_accuracy(val_loader):.4f}') | |
# 2.1 train whole master DNN (knowledge distillation) | |
        # ensure every MD parameter is trainable (note: master_dnn is only defined when the MD was
        # just generated above, so reference the registered model instead)
        for p in self.models['md'].models_dict['main'].parameters():
            p.requires_grad = True
self.models['md'].to_train_mode() | |
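        # resolve the optimizer and LR scheduler classes by name from torch.optim / torch.optim.lr_scheduler,
        # so both are fully controlled by the hyperparameter dict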
optimizer = torch.optim.__dict__[hyps['optimizer']]([ | |
{'params': self.models['md'].models_dict['main'].parameters(), **hyps['optimizer_args']} | |
]) | |
scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args']) | |
tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard']) | |
pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True) | |
best_avg_val_acc = 0. | |
md_output_hook = None | |
for iter_index in pbar: | |
self.models['md'].to_train_mode() | |
self.models['fm'].to_eval_mode() | |
# rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity'] | |
# elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity) | |
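            # (re-)register feature hooks on MD and FM; they capture the intermediate feature outputs
            # on which the distillation loss is computed (removed again before each validation pass)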
if md_output_hook is None: | |
md_output_hook = self.models['md'].get_feature_hook() | |
fm_output_hook = self.models['fm'].get_feature_hook() | |
x, y = next(train_loader) | |
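            # a batch may be a plain tensor or a dict of tensors (e.g. tokenizer output); move either form to the device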
if isinstance(x, dict): | |
for k, v in x.items(): | |
if isinstance(v, torch.Tensor): | |
x[k] = v.to(device) | |
y = y.to(device) | |
else: | |
x, y = x.to(device), y.to(device) | |
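            # run the frozen FM only to populate its feature hook (its direct output is not used);
            # the MD forward below produces the task loss and fills the MD hook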
with torch.no_grad(): | |
fm_output = self.models['fm'].infer(x) | |
task_loss = self.models['md'].forward_to_get_task_loss(x, y) | |
md_output = md_output_hook.output | |
fm_output = fm_output_hook.output | |
distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output) | |
total_loss = task_loss + distill_loss | |
optimizer.zero_grad() | |
total_loss.backward() | |
optimizer.step() | |
scheduler.step() | |
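            # periodic validation: remove the hooks, evaluate a deep copy of the MD so that the
            # training instance is untouched, and save the last/best checkpoints of MD and FM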
if (iter_index + 1) % hyps['val_freq'] == 0: | |
# elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main']) | |
md_output_hook.remove() | |
md_output_hook = None | |
fm_output_hook.remove() | |
fm_output_hook = None | |
cur_md = self.models['md'].models_dict['main'] | |
md_for_test = deepcopy(self.models['md'].models_dict['main']) | |
val_acc = 0. | |
self.models['md'].models_dict['main'] = md_for_test | |
self.models['md'].to_eval_mode() | |
val_acc = self.models['md'].get_accuracy(val_loader) | |
self.models['md'].models_dict['main'] = cur_md | |
self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt')) | |
self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt')) | |
if val_acc > best_avg_val_acc: | |
best_avg_val_acc = val_acc | |
self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt')) | |
self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt')) | |
tb_writer.add_scalars(f'losses', dict(task=task_loss, distill=distill_loss, total=total_loss), iter_index) | |
pbar.set_description(f'loss: {total_loss:.6f}') | |
if (iter_index + 1) >= hyps['val_freq']: | |
tb_writer.add_scalar(f'accs/val_acc', val_acc, iter_index) | |
pbar.set_description(f'loss: {total_loss:.6f}, val_acc: {val_acc:.4f}') | |
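

# Illustrative usage sketch (not part of the original code): the hyps dict below is one way to
# satisfy get_required_hyp_schema(); the exact constructor arguments of BaseAlg and the concrete
# hyperparameter values are assumptions, not the repository's reference configuration.
#
# alg = ElasticDNN_MDPretrainingWoFBSAlg(
#     models={'fm': fm, 'md': md},          # ElasticDNN_OfflineFMModel / ElasticDNN_OfflineMDModel
#     res_save_dir='logs/md_pretraining_wo_fbs'
# )
# alg.run(scenario, hyps={
#     'launch_tbboard': False,
#     'samples_size': (1, 3, 224, 224),
#     'generate_md_width_ratio': 4,
#     'train_batch_size': 128,
#     'val_batch_size': 256,
#     'num_workers': 8,
#     'optimizer': 'AdamW',
#     'optimizer_args': {'lr': 1e-4, 'weight_decay': 0.01},
#     'scheduler': 'StepLR',
#     'scheduler_args': {'step_size': 30000, 'gamma': 0.1},
#     'num_iters': 80000,
#     'val_freq': 4000,
#     'distill_loss_weight': 1.0,
# })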