"""Export the Real-ESRGAN x2plus RRDBNet checkpoint to an ONNX model file."""
import torch
import torch.onnx as onnx
from basicsr.archs.rrdbnet_arch import RRDBNet

# Input checkpoint and output ONNX file.
CHECKPOINT_PATH = 'Real-ESRGAN_x2plus.pth'
ONNX_PATH = 'Real-ESRGAN_x2plus.onnx'
OPSET_VERSION = 11

# Shape of the tracing input: (batch_size, channels, height, width).
# NOTE(review): with scale=2 the network presumably requires height/width
# divisible by 2 — confirm before changing these values.
input_shape = (1, 3, 64, 64)

device = torch.device('cpu')

# Build the PyTorch model; these hyper-parameters must match the ones the
# checkpoint was trained with, otherwise load_state_dict will fail.
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                num_grow_ch=32, scale=2)

# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints from a trusted source (on torch >= 1.13, weights_only=True is
# a safer option if the checkpoint contains only tensors).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Prefer the 'params_ema' weights; fall back to 'params', or treat the file
# as a bare state dict, so other basicsr-style checkpoints load too instead
# of failing with a KeyError.
if 'params_ema' in checkpoint:
    state_dict = checkpoint['params_ema']
elif 'params' in checkpoint:
    state_dict = checkpoint['params']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

# Switch to evaluation mode before tracing (eval() is the same as
# train(False), so a single call is enough).
model.eval()

# Random tensor used only to trace the graph; its values do not matter.
dummy_input = torch.randn(input_shape, device=device)

# Convert the model to ONNX. Only the batch dimension is marked dynamic;
# height and width are fixed to the tracing input's size.
onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    opset_version=OPSET_VERSION,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)