#!/usr/bin/env python3
"""Rename parameter keys in ``2_Dense/pytorch_model.bin`` in place.

The script:
1. loads the saved state dict,
2. keeps a byte-exact backup of the original file next to it (``.bak``), and
3. rewrites the file with keys renamed according to ``RENAME``.

Renaming is currently ``layer.weight`` -> ``linear.weight``.
NOTE(review): presumably this adapts an older Dense-layer checkpoint to a
module whose submodule is named ``linear``. Confirm whether a
``layer.bias`` key also needs mapping.
"""
import shutil
import sys
from collections import OrderedDict
from pathlib import Path

import torch

MODEL_PATH = Path("2_Dense/pytorch_model.bin")
BACKUP_PATH = MODEL_PATH.with_name(MODEL_PATH.name + ".bak")

# Old state-dict key -> new state-dict key.
RENAME = {"layer.weight": "linear.weight"}


def rename_keys(state_dict, mapping=RENAME, log=None):
    """Return a new OrderedDict whose keys are renamed according to *mapping*.

    Keys absent from *mapping* are kept unchanged. Insertion order is
    preserved, and *state_dict* itself is not modified. Each renamed key is
    reported on *log* (default: ``sys.stderr``).

    Raises:
        ValueError: if two entries would end up under the same key. The
            alternative would be to silently drop one of the tensors.
    """
    log = sys.stderr if log is None else log
    renamed = OrderedDict()
    for key, params in state_dict.items():
        dst = key
        if key in mapping:
            print(f"Mapping {key} to {mapping[key]}", file=log)
            dst = mapping[key]
        if dst in renamed:
            raise ValueError(f"Renaming {key!r} collides with existing key {dst!r}")
        renamed[dst] = params
    return renamed


def main(model_path=MODEL_PATH, backup_path=BACKUP_PATH):
    """Back up *model_path*, then rewrite it with renamed state-dict keys."""
    model_path = Path(model_path)
    backup_path = Path(backup_path)
    # Load onto CPU so a checkpoint saved from GPU also works on CPU-only hosts.
    state_dict = torch.load(model_path, map_location="cpu")
    # Copy raw bytes instead of re-serializing, so the backup is an exact copy.
    # Never overwrite an existing backup: a second run would otherwise replace
    # the true original with the already-renamed weights.
    if backup_path.exists():
        print(f"Backup {backup_path} already exists; leaving it untouched", file=sys.stderr)
    else:
        shutil.copy2(model_path, backup_path)
    torch.save(rename_keys(state_dict), model_path)


if __name__ == "__main__":
    main()