# HERIUN
# add models
# 591ba45
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import glob
import os
import warnings
import cv2
import numpy as np
import skimage.io as io
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from .GeoTr import U2NETP, GeoTr
warnings.filterwarnings("ignore")
class GeoTrP(nn.Module):
    """Wrap ``GeoTr`` and rescale its output map to ``img_size``-pixel coordinates.

    ``forward`` divides the raw ``GeoTr`` output by ``grid_size``, maps it to
    [-1, 1], maps that back to [0, 1] scaled by ``img_size``, and bilinearly
    resizes the result to a square ``(img_size, img_size)`` spatial size.

    Args:
        img_size: Side length of the square output map, also used as the
            coordinate scale of its values. Defaults to 2560.
        grid_size: Divisor applied to the raw ``GeoTr`` output before
            normalization. Defaults to 288.
            NOTE(review): presumably the resolution GeoTr predicts at — confirm
            against the GeoTr definition.
    """

    def __init__(self, img_size=2560, grid_size=288):
        super(GeoTrP, self).__init__()
        self.GeoTr = GeoTr()
        # A single source of truth for the output size: it is used both as
        # the coordinate scale and as the interpolation target below.
        self.img_size = img_size
        self.grid_size = grid_size

    def forward(self, x):
        """Run ``GeoTr`` on ``x`` and return the rescaled, resized map.

        ``F.interpolate`` in bilinear mode requires a 4-D (N, C, H, W) tensor,
        so ``GeoTr`` must return one; the meaning of its channels is not shown
        here — TODO confirm.
        """
        bm = self.GeoTr(x)
        # Normalize to [-1, 1], then map to [0, img_size]. Kept as two steps
        # (not one multiply by img_size / grid_size) so floating-point results
        # match the established behavior exactly.
        bm = 2 * (bm / self.grid_size) - 1
        bm = (bm + 1) / 2 * self.img_size
        bm = F.interpolate(
            bm,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=True,
        )
        return bm
def reload_model(model, path="", map_location="cpu"):
    """Load weights from a checkpoint file into ``model`` and return it.

    The loaded mapping is merged over the model's current ``state_dict()``.
    Entries present in the file overwrite the model's entries. Entries absent
    from the file keep the model's current values. ``load_state_dict`` then
    runs with its default ``strict=True``, so a file key the model does not
    have raises ``RuntimeError``.

    Args:
        model: ``torch.nn.Module`` to load into; it is modified in place.
        path: Checkpoint path. If falsy (e.g. ``""`` or ``None``), ``model`` is
            returned unchanged.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            loading works without CUDA. ``load_state_dict`` copies the values
            into the model's existing tensors, so the model keeps its own device.

    Returns:
        The same ``model`` object.
    """
    if not path:
        return model
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    # Assumes the file holds a state_dict-style mapping of name -> tensor.
    pretrained_dict = torch.load(path, map_location=map_location)
    print(len(pretrained_dict))
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model