Doven
update code.
f7009b3
raw
history blame
2.5 kB
import sys, os, json
# Locate the repository root: the prefix of this file's path up to and
# including the "Recurrent-Parameter-Generation" directory. The path is made
# absolute first so the lookup also works when __file__ is a relative path.
# Raises ValueError if that directory name is not part of the path.
_path_parts = os.path.abspath(__file__).split(os.sep)
root = os.sep + os.sep.join(_path_parts[1:_path_parts.index("Recurrent-Parameter-Generation") + 1])
# Make the project's packages importable and resolve relative paths
# (e.g. "./checkpoint/...") against the repository root.
sys.path.append(root)
os.chdir(root)
# torch
import torch
from torch import nn
# father: parent experiment module whose datasets, config and model are reused below
from workspace.condition import generalization as item
# Re-export the objects built by the parent module under local names.
train_set, test_set = item.train_set, item.test_set
config, model = item.config, item.model

# Cap the test set at its own real length and report that length.
# NOTE(review): exact semantics of set_infinite_dataset live in the dataset
# class, which is not visible here.
test_set.set_infinite_dataset(max_num=test_set.real_length)
print("num_generated:", test_set.real_length)

# The tag names both the checkpoint file and the generated outputs.
assert config.get("tag") is not None, "Remember to set a tag."
# Generation-time settings; merged into the shared config right after.
generate_config = dict(
    device="cuda",
    checkpoint=f"./checkpoint/{config['tag']}.pth",
    # Both templates keep everything before the last "/" of the dataset's
    # value and append a file name with two "{}" fields, which are filled in
    # later with the tag and a class label.
    generated_path=os.path.join(test_set.generated_path.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    test_command=os.path.join(test_set.test_command.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    need_test=True,
    # Set to an int to generate only that test-set index.
    specific_item=None,
)
config.update(generate_config)
# Model
print('==> Building model..')
# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; the model is moved to the target device below.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
diction = torch.load(config["checkpoint"], map_location="cpu")
# Rebuild the permutation embedding with the shape stored in the checkpoint
# so that load_state_dict finds a matching parameter
# (presumably the freshly built model may have a different number of entries).
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])
# generate
print('==> Defining generate..')


def generate(save_path=config["generated_path"], test_command=config["test_command"], need_test=True, index=None):
    """Generate parameters for one test-set condition, save them, and optionally test them.

    Args:
        save_path: Format template with two fields (tag, class label) for the
            output file. The default is captured from ``config`` when this
            function is defined.
        test_command: Format template with the same two fields; the result is
            run through ``os.system`` when ``need_test`` is true. The default
            is captured from ``config`` when this function is defined.
        need_test: Whether to run the test command after saving.
        index: Index into ``test_set`` selecting the condition to generate for.

    Returns:
        The model's prediction for the selected condition.
    """
    print("\n==> Generating..")
    model.eval()
    try:
        _, condition = test_set[index]
        # Read the condition entries as binary digits and turn them into a
        # zero-padded decimal label, e.g. entries 0,1,0,1 -> "0005".
        bits = "".join(str(int(bit)) for bit in condition)
        class_label = f"class{str(int(bits, 2)).zfill(4)}"
        with torch.no_grad():
            # condition[None] adds a leading dimension in front of the condition.
            prediction = model(sample=True, condition=condition[None], permutation_state=False)
            # Mean absolute value of the prediction, ignoring NaN entries.
            generated_norm = torch.nanmean((prediction.cpu()).abs())
        print("Generated_norm:", generated_norm.item())
        train_set.save_params(prediction, save_path=save_path.format(config["tag"], class_label))
        if need_test:
            os.system(test_command.format(config["tag"], class_label))
    finally:
        # Always put the model back into training mode, even if generation,
        # saving or testing raised.
        model.train()
    return prediction
if __name__ == "__main__":
    # One generation pass per test-set entry.
    for item_index in range(len(test_set)):
        # A configured specific_item overrides the loop index.
        # NOTE(review): the loop still runs len(test_set) times in that case,
        # so the same item is regenerated on every pass -- confirm this is
        # intended.
        if config["specific_item"] is not None:
            assert isinstance(config["specific_item"], int)
            item_index = config["specific_item"]
        output_pattern = config["generated_path"].format(config["tag"], "classXXX")
        print("Save to", output_pattern)
        generate(
            index=item_index,
            save_path=config["generated_path"],
            test_command=config["test_command"],
            need_test=config["need_test"],
        )