import argparse
import os

import numpy as np
import onnx
import onnxruntime
import torch
from monai.networks.nets import FlexibleUNet

# Use the current CUDA device when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def load_model_and_export(
    modelname, outname, out_channels, height, width, multigpu=False, in_channels=3, backbone="efficientnet-b0"
):
    """
    Load a FlexibleUNet checkpoint, export it to ONNX and verify the export with ONNX Runtime.

    The exported graph has one input ``INPUT__0`` of shape ``(batch_size, in_channels, height, width)``.
    Only the batch axis is dynamic; height and width are fixed at export time. It has one output,
    ``OUTPUT__0``. After export, the file is validated with ``onnx.checker``. The ONNX Runtime output
    on a fixed-seed random input is then compared against the PyTorch output.

    Args:
        modelname: full path of the state-dict checkpoint to load.
        outname: path where the ONNX model is written.
        out_channels: number of output channels, usually ``1 + number_of_classes``.
        height: input image height used for the export.
        width: input image width used for the export.
        multigpu: True if the checkpoint was saved from a ``torch.nn.DataParallel``-wrapped model.
        in_channels: number of input image channels.
        backbone: name of the backbone used by the FlexibleUNet.

    Raises:
        FileNotFoundError: if ``modelname`` does not exist.
        AssertionError: if the ONNX Runtime and PyTorch outputs differ beyond the tolerances.
    """
    if not os.path.exists(modelname):
        # FileNotFoundError subclasses Exception, so callers catching Exception keep working.
        raise FileNotFoundError("The specified model to load does not exist!")

    model = FlexibleUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        backbone=backbone,
        is_pad=False,
        pretrained=False,
        dropout=None,
    )

    # A checkpoint trained on multiple GPUs was saved from a DataParallel wrapper,
    # so wrap the model the same way before loading its state dict.
    if multigpu:
        model = torch.nn.DataParallel(model)
    # Use the module-level device instead of a hard-coded .cuda() so CPU-only hosts also work.
    model = model.to(device)
    model.load_state_dict(torch.load(modelname, map_location=device))
    model = model.eval()

    # Fixed-seed dummy input in NCHW layout (batch, channels, height, width).
    # A local generator avoids reseeding numpy's global RNG.
    rng = np.random.default_rng(0)
    x = torch.tensor(rng.random((1, in_channels, height, width)), dtype=torch.float32, device=device)
    with torch.no_grad():
        torch_out = model(x)

    input_names = ["INPUT__0"]
    output_names = ["OUTPUT__0"]
    # Export the bare network, not the DataParallel wrapper.
    model_trans = model.module if multigpu else model
    torch.onnx.export(
        model_trans,  # model to save
        x,  # model input
        outname,  # model save path
        export_params=True,
        verbose=True,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        opset_version=15,
        dynamic_axes={"INPUT__0": {0: "batch_size"}, "OUTPUT__0": {0: "batch_size"}},
    )

    onnx_model = onnx.load(outname)
    onnx.checker.check_model(onnx_model, full_check=True)
    # GPU builds of onnxruntime (>= 1.9) require an explicit provider list.
    # The CPU provider is always available and is sufficient for a numerical reference check.
    ort_session = onnxruntime.InferenceSession(outname, providers=["CPUExecutionProvider"])

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: x.detach().cpu().numpy()}
    ort_outs = ort_session.run(["OUTPUT__0"], ort_inputs)
    numpy_torch_out = torch_out.detach().cpu().numpy()
    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(numpy_torch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean; ``type=bool`` would turn any non-empty string (even "False") into True."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser()
    # the original model for converting.
    parser.add_argument(
        "--model", type=str, default=r"/workspace/models/model.pt", help="Input an existing model weight"
    )

    # path to save the onnx model.
    parser.add_argument(
        "--outpath", type=str, default=r"/workspace/models/model.onnx", help="A path to save the onnx model."
    )

    parser.add_argument("--width", type=int, default=736, help="Width for exporting onnx model.")

    parser.add_argument("--height", type=int, default=480, help="Height for exporting onnx model.")

    parser.add_argument(
        "--out_channels", type=int, default=2, help="Number of expected out_channels in model for exporting to onnx."
    )

    # nargs="?" + const=True accepts both a bare `--multigpu` and the explicit `--multigpu true/false` form.
    parser.add_argument(
        "--multigpu",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="If loading model trained with multi gpu. Use as a bare flag or with an explicit true/false.",
    )

    parser.add_argument("--in_channels", type=int, default=3, help="Number of input channels of the model.")

    parser.add_argument(
        "--backbone", type=str, default="efficientnet-b0", help="Backbone name used by the flexible unet."
    )

    args = parser.parse_args()

    # Refuse to clobber an existing export; FileExistsError is still an Exception subclass.
    if os.path.exists(args.outpath):
        raise FileExistsError(
            "The specified outpath already exists! Change the outpath to avoid overwriting your saved model. "
        )
    # load_model_and_export writes the ONNX file as a side effect and returns nothing useful.
    load_model_and_export(
        args.model,
        args.outpath,
        args.out_channels,
        args.height,
        args.width,
        multigpu=args.multigpu,
        in_channels=args.in_channels,
        backbone=args.backbone,
    )