import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import glob
import os
import warnings

import cv2
import numpy as np
import skimage.io as io
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from .GeoTr import U2NETP, GeoTr

warnings.filterwarnings("ignore")


class GeoTrP(nn.Module):
    """Wrap ``GeoTr`` and rescale its output map to a fixed output resolution.

    ``forward`` rescales the values produced by the wrapped ``GeoTr`` from an
    ``in_size`` grid to an ``out_size`` grid (net factor ``out_size / in_size``).
    It then bilinearly resizes the map spatially to ``out_size x out_size``.

    NOTE(review): treating GeoTr's output as coordinates on a 288-wide grid is
    inferred from the default constant used here; confirm against ``GeoTr``.
    """

    def __init__(self, *, in_size: float = 288, out_size: int = 2560):
        """Build the wrapped network.

        Args:
            in_size: Grid size that GeoTr's output values are expressed in
                (default 288, the original hard-coded value).
            out_size: Target grid size for the output values and for the
                spatial height/width of the returned map (default 2560).
        """
        super().__init__()
        self.GeoTr = GeoTr()
        # Plain attributes, not parameters or buffers, so the state_dict and
        # therefore checkpoint compatibility are unchanged.
        self.in_size = in_size
        self.out_size = out_size

    def forward(self, x):
        """Run GeoTr on ``x`` and return the rescaled, resized map.

        ``F.interpolate`` with ``mode="bilinear"`` and a 2-tuple ``size``
        requires a 4-D ``(N, C, H, W)`` tensor, so GeoTr's output must be 4-D.
        The result has spatial size ``(out_size, out_size)``.
        """
        bm = self.GeoTr(x)
        # If bm lies in [0, in_size], this step maps it to [-1, 1] ...
        bm = 2 * (bm / self.in_size) - 1
        # ... and this step maps [-1, 1] to [0, out_size]. The two steps are
        # kept separate (rather than a single multiply by out_size / in_size)
        # so the default configuration reproduces the original arithmetic.
        bm = (bm + 1) / 2 * self.out_size

        return F.interpolate(
            bm,
            size=(self.out_size, self.out_size),
            mode="bilinear",
            align_corners=True,
        )


def reload_model(model: nn.Module, path: str = "", *, map_location=None) -> nn.Module:
    """Load weights from ``path`` into ``model`` in place and return it.

    Keys present in the checkpoint overwrite the model's current entries.
    Keys missing from the checkpoint keep the model's current values.
    Checkpoint keys the model does not have still cause
    ``model.load_state_dict`` (strict mode) to raise.

    Args:
        model: Module to load the weights into.
        path: Checkpoint file path. If empty, ``model`` is returned untouched.
        map_location: Forwarded to ``torch.load``. If ``None``, ``"cuda:0"`` is
            used when CUDA is available, otherwise ``"cpu"``. The original
            hard-coded ``"cuda:0"`` failed on CPU-only hosts. The final
            parameter device is unaffected, because ``load_state_dict`` copies
            into the existing parameters.

    Returns:
        The same ``model`` object.
    """
    if not path:
        return model

    if map_location is None:
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"

    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file. Only load trusted checkpoints.
    pretrained_dict = torch.load(path, map_location=map_location)
    print(len(pretrained_dict))
    # Merge into the model's own state so partial checkpoints leave the
    # remaining entries at their current values.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    return model