Spaces:
Runtime error
Runtime error
File size: 2,995 Bytes
240c20c 2718a79 240c20c 2718a79 240c20c 2718a79 b7546a7 240c20c 2718a79 240c20c b79aac9 2718a79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import functools
import glob
import os
import warnings

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from model import DocGeoNet
from seg import U2NETP
# Install a blanket 'ignore' filter: no message/category/module restriction is
# given, so every warning raised after this point is suppressed process-wide.
warnings.filterwarnings('ignore')
class Net(nn.Module):
    """Chain a segmentation network and a rectification network.

    ``forward`` first asks ``self.msk`` for a mask, binarises it at 0.5 and
    multiplies it into the input, then feeds the masked input to
    ``self.DocTr`` and rescales the third value it returns.
    """

    def __init__(self):
        super().__init__()
        # Registration order (msk, then DocTr) determines state_dict key order.
        self.msk = U2NETP(3, 1)
        self.DocTr = DocGeoNet()

    def forward(self, x):
        """Return the rescaled third output of ``self.DocTr`` for the masked ``x``.

        Args:
            x: input tensor passed to the segmentation network.
               NOTE(review): presumably a (batch, 3, H, W) image in 0-1 --
               confirm against the caller.

        Returns:
            Tensor obtained by mapping ``bm`` through ``(2 * bm / 255 - 1) * 0.99``.
        """
        # The segmentation net must return exactly seven values; only the
        # first one is used, the remaining six are discarded.
        mask, _, _, _, _, _, _ = self.msk(x)
        binary_mask = (mask > 0.5).float()
        masked_input = binary_mask * x

        # The rectification net must return exactly three values; only the
        # last one is used.
        _, _, backward_map = self.DocTr(masked_input)

        # Linear rescale: an input of 0 maps to -0.99 and 255 maps to +0.99.
        normalised = 2 * (backward_map / 255.) - 1
        return normalised * 0.99
def reload_seg_model(model, path=""):
    """Load checkpoint weights into ``model`` and return it.

    Args:
        model: module whose ``state_dict`` is to be updated.
        path: checkpoint file to read with ``torch.load``; a falsy value
            (the default empty string) leaves ``model`` untouched.

    Returns:
        The same ``model`` object that was passed in.
    """
    if not bool(path):
        return model

    current_state = model.state_dict()
    checkpoint = torch.load(path, map_location='cpu')

    # Drop the first six characters of every checkpoint key and keep only the
    # entries the model actually has.
    # NOTE(review): presumably this strips a 6-character prefix such as
    # "model." -- confirm against the checkpoint file.
    for full_key, tensor in checkpoint.items():
        short_key = full_key[6:]
        if short_key in current_state:
            current_state[short_key] = tensor

    model.load_state_dict(current_state)
    return model
def reload_rec_model(model, path=""):
    """Load checkpoint weights into ``model`` and return it.

    Args:
        model: module whose ``state_dict`` is to be updated.
        path: checkpoint file to read with ``torch.load``; a falsy value
            (the default empty string) leaves ``model`` untouched.

    Returns:
        The same ``model`` object that was passed in.
    """
    if not bool(path):
        return model

    current_state = model.state_dict()
    checkpoint = torch.load(path, map_location='cpu')

    # Drop the first seven characters of every checkpoint key and keep only
    # the entries the model actually has.
    # NOTE(review): presumably this strips a 7-character prefix such as
    # "module." -- confirm against the checkpoint file.
    for full_key, tensor in checkpoint.items():
        short_key = full_key[7:]
        if short_key in current_state:
            current_state[short_key] = tensor

    model.load_state_dict(current_state)
    return model
@functools.lru_cache(maxsize=1)
def _load_net():
    """Build ``Net`` once, load both checkpoints and switch it to eval mode."""
    seg_model_path = './model_pretrained/preprocess.pth'
    rec_model_path = './model_pretrained/DocGeoNet.pth'
    net = Net()
    reload_rec_model(net.DocTr, rec_model_path)
    reload_seg_model(net.msk, seg_model_path)
    net.eval()
    return net


def rec(input_image):
    """Rectify a distorted document image.

    The image is scaled to 0-1, resized to 256 x 256 and passed through the
    network; the two predicted flow channels are resized back to the input
    resolution, smoothed with a 3 x 3 box blur and used as a sampling grid
    over the full-resolution image.

    Args:
        input_image: image data accepted by ``np.asarray``. Either a 2-D
            grayscale array (replicated to three channels) or an
            H x W x C array, of which the first three channels are used.
            Values are assumed to lie in 0-255 (they are divided by 255) --
            TODO confirm against the UI component configuration.

    Returns:
        PIL.Image.Image: the resampled image, with the same height, width
        and channel order as the input.

    Raises:
        ValueError: if ``input_image`` is ``None``.
    """
    if input_image is None:
        # Fail with a readable message instead of an IndexError from slicing.
        raise ValueError("No input image was provided.")

    pixels = np.asarray(input_image)
    if pixels.ndim == 2:
        # Grayscale input: replicate the single plane into three channels.
        pixels = np.stack([pixels] * 3, axis=2)

    im_ori = pixels[:, :, :3] / 255.  # read image 0-255 to 0-1
    h, w, _ = im_ori.shape

    im = cv2.resize(im_ori, (256, 256))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        # The network (and its weights) is created on first use and reused
        # afterwards rather than being reloaded from disk on every request.
        bm = _load_net()(im)
        bm = bm.cpu()

        bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
        bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
        bm0 = cv2.blur(bm0, (3, 3))
        bm1 = cv2.blur(bm1, (3, 3))

        lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 * h * w * 2
        source = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
        out = F.grid_sample(source, lbl, align_corners=True)

    # Channel order is left untouched: the previous RGB->BGR flip followed by
    # a BGR->RGB conversion cancelled out, so both steps were dropped.
    img_rec = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return Image.fromarray(np.ascontiguousarray(img_rec))
# Example images offered in the UI: every .jpg / .png file (any letter case)
# found directly inside ./distorted.
demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorted/*.[pP][nN][gG]')

# Gradio Interface
# The `gr.inputs` / `gr.outputs` namespaces were deprecated in Gradio 3 and
# removed in Gradio 4. Use the unified `gr.Image` component when it exists and
# fall back to the legacy namespaces on older releases.
if hasattr(gr, "Image"):
    input_image = gr.Image()
    output_image = gr.Image(type='pil')
else:
    input_image = gr.inputs.Image()
    output_image = gr.outputs.Image(type='pil')

iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet", examples=demo_img_files)
# Alternative for self-hosting on a fixed port, reachable from other machines:
# iface.launch(server_port=8821, server_name="0.0.0.0")
iface.launch()