from collections import OrderedDict | |
import torch.nn as nn | |
def layer_removal(
    model: nn.Module,
    layers_to_remove: OrderedDict
):
    """
    Remove attributes or container entries from ``model`` in place.

    Args:
        model: The module to modify. It is mutated in place; nothing is
            returned.
        layers_to_remove: Mapping from a dotted attribute path (for example
            ``"encoder.layers"`` or ``"features.0"``) to an index. Entries
            are processed in the mapping's iteration order.

            * If the index is ``None``, the attribute named by the last path
              segment is deleted from its parent via ``delattr``.
            * Otherwise the attribute named by the last path segment is
              treated as an indexable container (e.g. ``nn.Sequential`` or
              ``nn.ModuleList``) and ``del container[index]`` is performed.

    Raises:
        AttributeError: If any segment of a path cannot be resolved, or the
            final attribute does not exist when the index is ``None``.
        Whatever the container's ``__delitem__`` raises for an invalid index
        (e.g. ``IndexError``) is propagated unchanged.

    Note:
        Deletions are applied one after another, so an earlier entry can
        invalidate the path or indices of a later one.
    """
    for layer_name, layer_idx in layers_to_remove.items():
        # Split "a.b.c" into the parent path ["a", "b"] and the leaf name "c".
        *parent_path, leaf_name = layer_name.split(".")
        parent = model
        for segment in parent_path:
            parent = getattr(parent, segment)

        if layer_idx is None:
            # `del getattr(...)` is a SyntaxError (a call is not a valid del
            # target). delattr dispatches to the parent's __delattr__; for
            # nn.Module that unregisters the submodule/parameter/buffer.
            delattr(parent, leaf_name)
        else:
            # A subscription of the call result *is* a valid del target.
            del getattr(parent, leaf_name)[layer_idx]