ShortGpt / short_gpt /layer_removal.py
deepnet's picture
Upload folder using huggingface_hub
821537b verified
raw
history blame contribute delete
523 Bytes
from collections import OrderedDict
import torch.nn as nn
def layer_removal(
    model: nn.Module,
    layers_to_remove: OrderedDict
) -> None:
    """
    Remove sub-modules from ``model`` in place.

    Args:
        model: Root module to prune; it is mutated, nothing is returned.
        layers_to_remove: Mapping processed in iteration order. Each key is
            a dotted attribute path resolved from ``model`` (for example
            ``"model.layers"``). Each value is either ``None``, meaning the
            attribute named by the path is deleted entirely, or an index
            (anything the target container accepts in ``del container[idx]``)
            meaning only that entry of the container is deleted.

    Raises:
        AttributeError: If any component of a dotted path does not exist.
        Whatever the target container raises for an invalid index.
    """
    for layer_name, layer_idx in layers_to_remove.items():
        # Walk down to the parent that owns the final path component.
        *parent_path, leaf = layer_name.split(".")
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)

        if layer_idx is None:
            # ``del getattr(obj, name)`` is a SyntaxError (a call result is
            # not a deletable target); delattr is the dynamic-name
            # equivalent of ``del obj.name``.
            delattr(parent, leaf)
        else:
            # Item deletion is delegated to the container itself.
            container = getattr(parent, leaf)
            del container[layer_idx]