import os

# Install mmcv via openmim at runtime, before the model definitions are imported.
os.system('mim install mmcv')
import numpy as np
import models
import gradio as gr

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def construct_sample(img, mean=0.5, std=0.5):
    """Convert a PIL image into a normalized 48x48 LR tensor."""
    img = transforms.ToTensor()(img)
    img = transforms.Resize((48, 48), interpolation=InterpolationMode.BICUBIC)(img)
    img = transforms.Normalize(mean, std)(img)  # map [0, 1] to [-1, 1]
    return img
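
# Note: every upload is resized to a fixed 48x48 LR tensor, so the SR output is
# 48*scale pixels per side. Quick shape check (illustrative, assuming a PIL
# image `im`):
#   x = construct_sample(im)   # torch.Size([3, 48, 48]), values roughly in [-1, 1]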

def build_model(cp):
    """Load a FunSR checkpoint and move the model to the target device."""
    model_spec = torch.load(cp, map_location='cpu')['model']
    print(model_spec['args'])
    model = models.make(model_spec, load_sd=True).to(device)
    return model
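
# build_model assumes the .pth checkpoint stores a spec dict of the form
# {'model': {'name': ..., 'args': ..., 'sd': state_dict}}, which is what
# models.make(..., load_sd=True) consumes. A per-checkpoint cache (a sketch,
# not used by the demo below) would avoid re-reading the file on each request:
#   _model_cache = {}
#   def get_model(cp):
#       if cp not in _model_cache:
#           _model_cache[cp] = build_model(cp)
#       return _model_cache[cp]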


# Super-resolution inference: run FunSR at the requested scale and return the
# LR input and SR output side by side.
def sr_func(img, cp, scale):
    # Map the selected dataset to its pretrained FunSR (RDN backbone) weights
    if cp == 'UC':
        checkpoint = 'pretrain/UC_FunSR_RDN.pth'
    elif cp == 'AID':
        checkpoint = 'pretrain/AID_FunSR_RDN.pth'
    else:
        raise NotImplementedError
    sample = construct_sample(img)
    print('Using device:', device)
    model = build_model(checkpoint)  # reloaded on every request
    model.eval()
    sample = sample.unsqueeze(0).to(device)  # add batch dimension: BCHW

    ori_size = torch.tensor(sample.shape[2:])  # spatial size (H, W) of BCHW input
    target_size = (ori_size * scale).long()
    # Nearest-neighbor upsampling of the LR input, kept only for visual comparison
    lr_target_size_img = torch.nn.functional.interpolate(sample, scale_factor=scale, mode='nearest')
    with torch.no_grad():
        pred = model(sample, target_size.tolist())

    if isinstance(pred, list):  # some models return multi-stage outputs; keep the last
        pred = pred[-1]

    # Undo the (0.5, 0.5) normalization and rescale to [0, 255]
    pred = (pred * 0.5 + 0.5) * 255
    pred = pred[0].detach().cpu()
    lr_target_size_img = (lr_target_size_img * 0.5 + 0.5) * 255
    lr_target_size_img = lr_target_size_img[0].detach().cpu()

    lr_target_size_img = torch.clamp(lr_target_size_img, 0, 255).permute(1, 2, 0).numpy().astype(np.uint8)
    pred = torch.clamp(pred, 0, 255).permute(1, 2, 0).numpy().astype(np.uint8)

    # 5-pixel white divider between the upsampled LR input and the SR result
    line = np.ones((pred.shape[0], 5, 3), dtype=np.uint8) * 255
    pred = np.concatenate((lr_target_size_img, line, pred), axis=1)
    return pred
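
# Local smoke test (illustrative; assumes the example image and the pretrained
# checkpoint exist on disk):
#   from PIL import Image
#   out = sr_func(Image.open('examples/UC_airplane95_LR.png'), 'UC', 4.0)
#   print(out.shape)  # (192, 389, 3): 192x192 LR + 5 px divider + 192x192 SR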

title = "FunSR"
description = "Gradio demo for continuous remote sensing image super-resolution. Upload image from UCMerced or AID Dataset or click any one of the examples, " \
              "Then change the upscaling magnification, and click \"Submit\" and wait for the super-resolved result. \n" \
              "Paper: Continuous Remote Sensing Image Super-Resolution based on Context Interaction in Implicit Function Space"

article = "<p style='text-align: center'><a href='https://kyanchen.github.io/FunSR/' target='_blank'>FunSR Project " \
          "Page</a></p> "

default_scale = 2.0
examples = [
    ['examples/AID_school_161_LR.png', 'AID', default_scale],
    ['examples/AID_bridge_19_LR.png', 'AID', default_scale],
    ['examples/AID_parking_60_LR.png', 'AID', default_scale],
    ['examples/AID_commercial_32_LR.png', 'AID', default_scale],

    ['examples/UC_airplane95_LR.png', 'UC', default_scale],
    ['examples/UC_freeway35_LR.png', 'UC', default_scale],
    ['examples/UC_storagetanks54_LR.png', 'UC', default_scale],
    ['examples/UC_airplane00_LR.png', 'UC', default_scale],
]

# Components are instantiated unrendered and laid out by the Interface below.
image_input = gr.Image(type='pil', label='Input Img')
image_output = gr.Image(label='SR Result', type='numpy')
checkpoint = gr.Radio(['UC', 'AID'], label='Checkpoint')
scale = gr.Slider(1, 10, value=4.0, step=0.1, label='Scale')

io = gr.Interface(fn=sr_func,
                  inputs=[image_input, checkpoint, scale],
                  outputs=[image_output],
                  title=title,
                  description=description,
                  article=article,
                  allow_flagging='auto',
                  examples=examples,
                  cache_examples=True,
                  )
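# launch() serves the demo locally; io.launch(share=True) would also create a
# temporary public link.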
io.launch()