import os
import random
from copy import deepcopy
from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
import tqdm
from schema import Schema, Or
from torch import nn

from data import Scenario, MergedDataset, build_dataloader
from methods.base.alg import BaseAlg
from utils.common.log import logger
from utils.dl.common.env import create_tbwriter
from utils.dl.common.model import LayerActivation, get_module

from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from ...model.base import ElasticDNNUtil

class ElasticDNN_MDPretrainingIndexAlg(BaseAlg):
    """
    Construct an index between each filter/row of the master DNN (MD) and all
    filters/rows of the foundation model (FM) in the same layer.

    Note: the resulting indexes are very large (~1GB in total), so training is
    slow and the indexes are hard to optimize.
    """

    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel,
            'md': ElasticDNN_OfflineMDModel
        })

    def get_required_hyp_schema(self) -> Schema:
        return Schema({
            'launch_tbboard': bool,

            'samples_size': (int, int, int, int),

            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'indexes_optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,
            'index_loss_l1_weight': float,
            'val_num_sparsities': int,

            'bn_cal_num_iters': int,

            'index_guided_linear_comb_split_size': Or(int, None)
        })

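    # Illustrative example of a hyperparameter dict that satisfies the schema
    # above. The values are placeholders, not the configuration of any
    # particular experiment:
    #
    # hyps = {
    #     'launch_tbboard': False,
    #     'samples_size': (1, 3, 224, 224),
    #     'train_batch_size': 128,
    #     'val_batch_size': 512,
    #     'num_workers': 16,
    #     'optimizer': 'AdamW',
    #     'indexes_optimizer_args': {'lr': 3e-3, 'weight_decay': 0.1},
    #     'scheduler': 'StepLR',
    #     'scheduler_args': {'step_size': 20000, 'gamma': 0.5},
    #     'num_iters': 80000,
    #     'val_freq': 4000,
    #     'index_loss_l1_weight': 0.01,
    #     'val_num_sparsities': 4,
    #     'bn_cal_num_iters': 800,
    #     'index_guided_linear_comb_split_size': 128
    # }
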
    def upsample_2d_tensor(self, p: torch.Tensor, target_len: int):
        assert p.dim() == 2
        # (N, D) -> (N, 1, D, 1) so the feature dimension can be bilinearly
        # resized, then squeeze back to (N, target_len).
        # F.interpolate is the non-deprecated equivalent of F.upsample.
        return F.interpolate(p.unsqueeze(1).unsqueeze(3),
                             size=(target_len, 1),
                             mode='bilinear').squeeze(3).squeeze(1)

    def two_params_diff_fast(self, trained_p: torch.Tensor, ref_p: torch.Tensor,
                             index: torch.Tensor,
                             split_size: int):
        assert trained_p.dim() == ref_p.dim()
        # index maps each row of the MD parameter to all rows of the FM parameter
        assert index.size(0) == trained_p.size(0) and index.size(1) == ref_p.size(0)

        ref_p = ref_p.detach()
        if trained_p.dim() > 1:
            trained_p = trained_p.flatten(1)
            ref_p = ref_p.flatten(1)

            # if the MD parameter is narrower than the FM parameter,
            # upsample it so both sides have the same width
            if trained_p.size(1) < ref_p.size(1):
                trained_p = self.upsample_2d_tensor(trained_p, ref_p.size(1))

        index = index.unsqueeze(-1)

        if split_size is None:
            # full linear combination over all FM rows at once
            linear_combed_ref_p = (ref_p.unsqueeze(0) * index).sum(1)
        else:
            # chunked linear combination to reduce peak memory usage
            linear_combed_ref_p = 0

            cur_split_size = split_size
            while index.size(1) % cur_split_size != 0:
                cur_split_size -= 1

            for i in range(0, index.size(1), cur_split_size):
                linear_combed_ref_p += ref_p.unsqueeze(0)[:, i: i + cur_split_size] * index[:, i: i + cur_split_size]
            linear_combed_ref_p = linear_combed_ref_p.sum(1)

        diff = (linear_combed_ref_p - trained_p).norm(2) ** 2
        return diff

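    # For an MD weight of shape (n_md, d) and its matched FM weight of shape
    # (n_fm, d), the index has shape (n_md, n_fm) and two_params_diff_fast
    # computes || index @ ref_p - trained_p ||_2^2, i.e. every MD row is
    # reconstructed as a learned linear combination of all FM rows in the same
    # layer (the split_size branch is just a chunked, memory-friendlier form of
    # the same sum). A minimal shape sketch with illustrative sizes only:
    #
    #   trained_p: (768, 3072)   # MD weight, flattened (upsampled if narrower)
    #   ref_p:     (3072, 3072)  # FM weight, flattened
    #   index:     (768, 3072)   # one coefficient per (MD row, FM row) pair
    #   index @ ref_p -> (768, 3072), compared against trained_p
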
    def get_index_loss(self, fm, md, indexes, match_fn, split_size):
        res = 0.

        for name, p in md.named_parameters():
            if name not in indexes.keys():
                continue

            raw_p = match_fn(name, fm)

            index = indexes[name]

            res += self.two_params_diff_fast(p, raw_p, index, split_size)
        return res

    def bn_cal(self, model: nn.Module, train_loader, num_iters, device):
        has_bn = False
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                has_bn = True
                break

        if not has_bn:
            return {}

        def bn_calibration_init(m):
            """ calculating post-statistics of batch normalization """
            if getattr(m, 'track_running_stats', False):
                m.reset_running_stats()
                m.training = True

        with torch.no_grad():
            model.eval()
            model.apply(bn_calibration_init)
            for _ in range(num_iters):
                x, _ = next(train_loader)
                model(x.to(device))
            model.eval()

        bn_stats = {}
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_stats[n] = m
        return bn_stats

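    # bn_cal() follows the common BatchNorm recalibration recipe (popularized by
    # slimmable networks): reset the running statistics, forward a few training
    # batches in training mode so the statistics are re-estimated for the current
    # sparsity, then switch back to eval mode. The returned dict simply maps
    # module names to the recalibrated BN modules.
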
    def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]:
        super().run(scenario, hyps)

        # narrow the model types for type checking / IDE auto-completion
        assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel)
        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)

        device = self.models['md'].device

        elastic_dnn_util: ElasticDNNUtil = self.models['fm'].get_elastic_dnn_util()

        # merge the train/val splits of all offline datasets
        offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
                                             True, None))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
                                      False, False)

        # construct the trainable indexes: for every >1-d MD parameter that has a
        # matched FM parameter, build a (md_rows, fm_rows) coefficient matrix
        indexes = {}
        for name, p in self.models['md'].models_dict['main'].named_parameters():
            if p.dim() > 1:
                matched_p_in_fm = self.models['md'].get_matched_param_of_fm(name, self.models['fm'].models_dict['main'])
                if matched_p_in_fm is None:
                    continue
                indexes[name] = torch.zeros((p.size(0), matched_p_in_fm.size(0))).to(device)
                indexes[name].requires_grad = True

                logger.info(f'construct index in layer {name}')

        # report the total size of the indexes by serializing them once
        tmp_indexes_file_path = os.path.join(self.res_save_dir, 'tmp-indexes.pt')
        torch.save(indexes, tmp_indexes_file_path)
        logger.info(f'generate indexes ({(os.path.getsize(tmp_indexes_file_path) / 1024**2):.3f}MB)')
        os.remove(tmp_indexes_file_path)

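        # Rough cost intuition (hypothetical sizes): a single fp32 index for a
        # layer with 768 MD rows and 3072 FM rows takes 768 * 3072 * 4 B ≈ 9 MB;
        # summing such matrices over every matched layer of a transformer-sized
        # FM is what produces the ~1GB total noted in the class docstring.
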
        # only the indexes are optimized; both the FM and the MD stay frozen
        optimizer = torch.optim.__dict__[hyps['optimizer']]([
            {'params': [v for v in indexes.values()], **hyps['indexes_optimizer_args']}
        ])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
        best_avg_val_acc = 0.

        for p in self.models['md'].models_dict['main'].parameters():
            p.requires_grad = False
        for p in self.models['fm'].models_dict['main'].parameters():
            p.requires_grad = False

        for iter_index in pbar:
            self.models['md'].to_eval_mode()
            self.models['fm'].to_eval_mode()

            # reconstruction loss: each MD parameter should be reproducible as an
            # index-weighted linear combination of the matched FM parameter's rows
            index_loss = self.get_index_loss(self.models['fm'].models_dict['main'],
                                             self.models['md'].models_dict['main'],
                                             indexes,
                                             self.models['md'].get_matched_param_of_fm,
                                             hyps['index_guided_linear_comb_split_size'])
            # L1 penalty encouraging sparse indexes; torch.stack keeps the
            # computation graph intact so gradients reach the indexes
            index_l1_loss = hyps['index_loss_l1_weight'] * torch.stack([v.abs().sum() for v in indexes.values()]).sum()

            total_loss = index_loss + index_l1_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()

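            # Periodic diagnostics: every 10 iterations the distribution of each
            # index is dumped as a 20-bin histogram image, which makes it easy to
            # see whether the L1 penalty is actually pushing most coefficients
            # towards zero.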
            if (iter_index + 1) % 10 == 0:
                os.makedirs(os.path.join(self.res_save_dir, 'index_hist'), exist_ok=True)
                with torch.no_grad():
                    for p_name, index in indexes.items():
                        index_hist = index.view(-1).histc(bins=20).detach().cpu().numpy()
                        plt.bar(list(range(20)), index_hist)
                        plt.savefig(os.path.join(self.res_save_dir, f'index_hist/{p_name}.png'))
                        plt.clf()

            if (iter_index + 1) % hyps['val_freq'] == 0:
                elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main'])

                # evaluate a copy of the MD at several sparsities; the hardcoded list
                # below is assumed to contain exactly hyps['val_num_sparsities'] entries
                cur_md = self.models['md'].models_dict['main']
                md_for_test = deepcopy(self.models['md'].models_dict['main'])
                val_accs = {}
                avg_val_acc = 0.
                bn_stats = {}

                for val_sparsity in [0.0, 0.2, 0.4, 0.8]:
                    elastic_dnn_util.set_master_dnn_sparsity(md_for_test, val_sparsity)
                    bn_stats[f'{val_sparsity:.4f}'] = self.bn_cal(md_for_test, train_loader, hyps['bn_cal_num_iters'], device)
                    self.models['md'].models_dict['main'] = md_for_test
                    self.models['md'].to_eval_mode()
                    val_acc = self.models['md'].get_accuracy(val_loader)

                    val_accs[f'{val_sparsity:.4f}'] = val_acc
                    avg_val_acc += val_acc
                avg_val_acc /= hyps['val_num_sparsities']

                # restore the original MD and attach the learned indexes / BN stats
                self.models['md'].models_dict['main'] = cur_md
                self.models['md'].models_dict['indexes'] = indexes
                self.models['md'].models_dict['bn_stats'] = bn_stats
                self.models['fm'].models_dict['indexes'] = indexes

                self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt'))
                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))

                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt'))
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

            tb_writer.add_scalars('losses', dict(index=index_loss, index_l1=index_l1_loss, total=total_loss), iter_index)
            pbar.set_description(f'loss: {total_loss:.6f}, index_loss: {index_loss:.6f}, index_l1_loss: {index_l1_loss:.6f}')
            if (iter_index + 1) >= hyps['val_freq']:
                tb_writer.add_scalars('accs/val_accs', val_accs, iter_index)
                tb_writer.add_scalar('accs/avg_val_acc', avg_val_acc, iter_index)
                pbar.set_description(f'loss: {total_loss:.6f}, index_loss: {index_loss:.6f}, index_l1_loss: {index_l1_loss:.6f}, '
                                     f'avg_val_acc: {avg_val_acc:.4f}')