""" Simple test script for the trained model """ import os import torch import tiktoken from model import GPTConfig, GPT def test_model(): # Load model ckpt_path = "out-srs/ckpt_000600.pt" print(f"Loading {ckpt_path}...") checkpoint = torch.load(ckpt_path, map_location="mps") gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) # Load weights state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k, v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) model.eval() model.to("mps") print(f"Model loaded! (iteration {checkpoint['iter_num']})") # Test generation enc = tiktoken.get_encoding("gpt2") encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) decode = lambda l: enc.decode(l) prompt = "Hello, how are you?" print(f"\nPrompt: {prompt}") start_ids = encode(prompt) x = torch.tensor(start_ids, dtype=torch.long, device="mps")[None, ...] with torch.no_grad(): y = model.generate(x, 50, temperature=0.8, top_k=200) result = decode(y[0].tolist()) print(f"Generated: {result}") return True if __name__ == "__main__": test_model()