Spaces:
Runtime error
Runtime error
File size: 2,480 Bytes
d59aeff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import torch
import torch.nn as nn
import imp
import numpy as np
class Base(nn.Module):
    """Shared bookkeeping for models: step counter, checkpoints, logging.

    Registers two buffers, ``step`` (a one-element long tensor) and
    ``stop_threshold`` (a float32 scalar), so both travel with the module's
    ``state_dict``.
    """

    def __init__(self, stop_threshold):
        """Set up the module and register the ``step``/``stop_threshold`` buffers.

        Args:
            stop_threshold: Value stored as a float32 tensor buffer named
                ``stop_threshold``.
        """
        super().__init__()

        # NOTE(review): nothing has been registered on the module by this
        # class yet, so these two calls see no parameters here; subclasses
        # that add layers afterwards presumably call them again -- confirm.
        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer(
            "stop_threshold",
            torch.tensor(stop_threshold, dtype=torch.float32),
        )

    @property
    def r(self):
        """Return ``self.decoder.r`` as a Python scalar.

        ``self.decoder`` is not defined in this class; it is expected to be
        provided elsewhere (presumably by a subclass).
        """
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        # Build the replacement from the existing tensor so that dtype and
        # device are inherited from the current ``decoder.r``.
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def init_model(self):
        """Apply Xavier-uniform initialisation to every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def finetune_partial(self, whitelist_layers):
        """Clear gradients, then report and update the listed child modules.

        For each direct child whose name is in ``whitelist_layers``, the
        name and parameter count are printed and ``requires_grad`` is set
        to ``False`` on all of that child's parameters.

        NOTE(review): the printed messages call these layers "Trainable",
        yet the code disables their gradients. The behaviour is kept as-is;
        confirm which of the two is intended.

        Args:
            whitelist_layers: Container of child-module names to match.
        """
        self.zero_grad()
        for name, child in self.named_children():
            if name not in whitelist_layers:
                continue
            print("Trainable Layer: %s" % name)
            print("Trainable Parameters: %.3f" % sum(p.numel() for p in child.parameters()))
            for param in child.parameters():
                param.requires_grad = False

    def get_step(self):
        """Return the current value of the ``step`` buffer as a Python int."""
        return self.step.item()

    def reset_step(self):
        """Set the ``step`` buffer to 1.

        The replacement keeps the buffer's shape, dtype and device so that a
        ``state_dict`` taken after a reset can still be loaded into a freshly
        constructed model (a 0-dim replacement would fail with a size
        mismatch against the registered one-element buffer).
        """
        # Assignment to a registered buffer name is intercepted by nn.Module
        # and updates the internal buffer entry.
        self.step = torch.ones_like(self.step)

    def log(self, path, msg):
        """Append ``msg`` followed by a newline to the file at ``path``."""
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, device, optimizer=None):
        """Restore model (and optionally optimizer) state from a checkpoint.

        Args:
            path: Checkpoint location; converted with ``str()``.
            device: Passed to ``torch.load`` as ``map_location``.
            optimizer: If given and the checkpoint contains
                ``"optimizer_state"``, that state is loaded into it.
        """
        # SECURITY: torch.load unpickles data; only open trusted checkpoints.
        # Tensors are mapped onto the caller-supplied device.
        checkpoint = torch.load(str(path), map_location=device)
        # strict=False: missing or unexpected keys are tolerated.
        self.load_state_dict(checkpoint["model_state"], strict=False)

        if optimizer is not None and "optimizer_state" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        """Write a checkpoint containing the model state.

        Args:
            path: Destination; converted with ``str()``.
            optimizer: If given, its ``state_dict`` is stored under
                ``"optimizer_state"`` next to ``"model_state"``.
        """
        payload = {"model_state": self.state_dict()}
        if optimizer is not None:
            payload["optimizer_state"] = optimizer.state_dict()
        torch.save(payload, str(path))

    def num_params(self, print_out=True):
        """Return the number of trainable parameters, in millions.

        Args:
            print_out: When true, also print the count.

        Returns:
            Total element count of parameters with ``requires_grad`` set,
            divided by 1,000,000.
        """
        parameters = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters
|