Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +1 -0
- .gradio/certificate.pem +31 -0
- __pycache__/inference.cpython-310.pyc +0 -0
- app.py +42 -4
- ckpt/epoch_287.pth +3 -0
- inference.py +59 -0
- text_net/DGRN.py +232 -0
- text_net/__pycache__/DGRN.cpython-310.pyc +0 -0
- text_net/__pycache__/DGRN.cpython-38.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-310.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-36.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-38.pyc +0 -0
- text_net/__pycache__/encoder.cpython-310.pyc +0 -0
- text_net/__pycache__/encoder.cpython-36.pyc +0 -0
- text_net/__pycache__/encoder.cpython-38.pyc +0 -0
- text_net/__pycache__/moco.cpython-310.pyc +0 -0
- text_net/__pycache__/moco.cpython-36.pyc +0 -0
- text_net/__pycache__/moco.cpython-38.pyc +0 -0
- text_net/__pycache__/model.cpython-310.pyc +0 -0
- text_net/__pycache__/model.cpython-36.pyc +0 -0
- text_net/__pycache__/model.cpython-38.pyc +0 -0
- text_net/deform_conv.py +65 -0
- text_net/encoder.py +67 -0
- text_net/moco.py +166 -0
- text_net/model.py +29 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-36.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-310.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-36.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-38.pyc +0 -0
- utils/__pycache__/dataset_utils_CDD.cpython-310.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-310.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-36.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-38.pyc +0 -0
- utils/__pycache__/image_io.cpython-310.pyc +0 -0
- utils/__pycache__/image_io.cpython-36.pyc +0 -0
- utils/__pycache__/image_io.cpython-38.pyc +0 -0
- utils/__pycache__/image_utils.cpython-310.pyc +0 -0
- utils/__pycache__/image_utils.cpython-36.pyc +0 -0
- utils/__pycache__/image_utils.cpython-38.pyc +0 -0
- utils/__pycache__/imresize.cpython-36.pyc +0 -0
- utils/__pycache__/imresize.cpython-38.pyc +0 -0
- utils/__pycache__/loss_utils.cpython-38.pyc +0 -0
- utils/__pycache__/val_utils.cpython-310.pyc +0 -0
- utils/__pycache__/val_utils.cpython-36.pyc +0 -0
- utils/__pycache__/val_utils.cpython-38.pyc +0 -0
- utils/dataset_utils.py +309 -0
.gitattributes
CHANGED
@@ -19,6 +19,7 @@
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
__pycache__/inference.cpython-310.pyc
ADDED
Binary file (1.99 kB).
app.py
CHANGED
@@ -1,7 +1,45 @@
 import gradio as gr
+from inference import infer

-def greet(…
-…
+def greet(image, prompt):
+    restore_img = infer(img=image, text_prompt=prompt)
+    return restore_img

-…
-…
+
+title = "🖼️ ICDR 🖼️"
+description = ''' ## ICDR: Image Restoration Framework for Composite Degradation following Human Instructions
+Our Github : https://github.com/
+
+Siwon Kim, Donghyeon Yoon
+
+Ajou Univ
+'''
+
+
+article = "<p style='text-align: center'><a href='https://github.com/' target='_blank'>ICDR</a></p>"
+
+# Image / prompt examples
+examples = [['input/00010.png', "I love this photo, could you remove the haze and more brighter?"],
+            ['input/00058.png', "I have to post an emotional shot on Instagram, but it was shot too foggy and too dark. Change it like a sunny day and brighten it up!"]]
+
+css = """
+.image-frame img, .image-container img {
+    width: auto;
+    height: auto;
+    max-width: none;
+}
+"""
+
+
+demo = gr.Interface(
+    fn=greet,
+    inputs=[gr.Image(type="pil", label="Input"),
+            gr.Text(label="Prompt")],
+    outputs=[gr.Image(type="pil", label="Output")],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+    css=css,
+)
+demo.launch(share=True)
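
For a quick smoke test without the web UI, the backend can be called directly. A minimal sketch, assuming the checkpoint at ckpt/epoch_287.pth, a CUDA device, and that save_image_tensor returns a PIL-compatible image (which the gr.Image(type="pil") output above implies); input/00010.png is one of the bundled examples, and the script must be run with no extra CLI flags because infer() parses sys.argv:

from PIL import Image
from inference import infer

img = Image.open("input/00010.png").convert("RGB")   # sample image from the examples list
# The return type is assumed PIL-compatible, as the Gradio output spec suggests.
restored = infer(img=img, text_prompt="remove the haze and make it brighter")
restored.save("restored.png")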
ckpt/epoch_287.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db279692728bd4614759c08a0478d9d07200768e5fb7fa893e78aaa05f3ca707
+size 48705338
inference.py
ADDED
@@ -0,0 +1,59 @@
+import argparse
+import subprocess
+from tqdm import tqdm
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader
+
+from utils.dataset_utils_CDD import DerainDehazeDataset
+from utils.val_utils import AverageMeter, compute_psnr_ssim
+from utils.image_io import save_image_tensor
+
+from text_net.model import AirNet
+
+
+def test_Derain_Dehaze(opt, net, dataset, task="derain"):
+    output_path = opt.output_path + task + '/'
+    subprocess.check_output(['mkdir', '-p', output_path])
+
+    # dataset.set_dataset(task)
+    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
+    print(len(testloader))
+
+    with torch.no_grad():
+        for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
+            degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()
+            restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt=text_prompt)
+
+            # The demo feeds a single image, so return after the first batch.
+            return save_image_tensor(restored)
+
+
+def infer(text_prompt="", img=None):
+    parser = argparse.ArgumentParser()
+    # Input Parameters
+    parser.add_argument('--cuda', type=int, default=0)
+    parser.add_argument('--derain_path', type=str, default="data/Test_prompting/", help='save path of test raining images')
+    parser.add_argument('--output_path', type=str, default="output/demo11", help='output save path')
+    parser.add_argument('--ckpt_path', type=str, default="ckpt/epoch_287.pth", help='checkpoint save path')
+    # parser.add_argument('--text_prompt', type=str, default="derain")
+
+    opt = parser.parse_args()
+    # opt.text_prompt = text_prompt
+
+    np.random.seed(0)
+    torch.manual_seed(0)
+    torch.cuda.set_device(opt.cuda)
+
+    opt.batch_size = 7
+    ckpt_path = opt.ckpt_path
+
+    derain_set = DerainDehazeDataset(opt, img=img, text_prompt=text_prompt)
+
+    # Make network
+    net = AirNet(opt).cuda()
+    net.eval()
+    net.load_state_dict(torch.load(ckpt_path, map_location=torch.device(opt.cuda)))
+
+    restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")
+
+    return restored
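
Because infer() builds its options with parser.parse_args(), it also consumes sys.argv, which can clash with an outer CLI. A hypothetical programmatic variant (infer_programmatic is not part of the repo) builds the same options with argparse.Namespace, mirroring the add_argument defaults above; it assumes the CDD dataset class only reads the fields set here, which this diff does not show:

import argparse
import numpy as np
import torch

from utils.dataset_utils_CDD import DerainDehazeDataset
from text_net.model import AirNet
from inference import test_Derain_Dehaze


def infer_programmatic(img, text_prompt=""):
    # Hypothetical wrapper: same fields as the argparse defaults in infer().
    opt = argparse.Namespace(
        cuda=0,
        derain_path="data/Test_prompting/",
        output_path="output/demo11",
        ckpt_path="ckpt/epoch_287.pth",
        batch_size=7,
    )
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(opt.cuda)

    derain_set = DerainDehazeDataset(opt, img=img, text_prompt=text_prompt)
    net = AirNet(opt).cuda().eval()
    net.load_state_dict(torch.load(opt.ckpt_path, map_location=torch.device(opt.cuda)))
    return test_Derain_Dehaze(opt, net, derain_set, task="derain")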
text_net/DGRN.py
ADDED
@@ -0,0 +1,232 @@
+import torch.nn as nn
+import torch
+from .deform_conv import DCN_layer
+import clip
+
+clip_model, preprocess = clip.load("ViT-B/32", device='cuda')
+
+# Dynamically read the text embedding dimension from the CLIP model
+text_embed_dim = clip_model.text_projection.shape[1]
+
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
+
+
+class DGM(nn.Module):
+    def __init__(self, channels_in, channels_out, kernel_size):
+        super(DGM, self).__init__()
+        self.channels_out = channels_out
+        self.channels_in = channels_in
+        self.kernel_size = kernel_size
+
+        self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
+                             padding=(kernel_size - 1) // 2, bias=False)
+        self.sft = SFT_layer(self.channels_in, self.channels_out)
+
+        self.relu = nn.LeakyReLU(0.1, True)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation map: B * C * H * W
+        '''
+        dcn_out = self.dcn(x, inter)
+        sft_out = self.sft(x, inter, text_prompt)
+        out = dcn_out + sft_out
+        out = x + out
+
+        return out
+
+
+# Projection head that maps CLIP text embeddings into the feature channel space
+class TextProjectionHead(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(TextProjectionHead, self).__init__()
+        self.proj = nn.Sequential(
+            nn.Linear(input_dim, output_dim),
+            nn.ReLU(),
+            nn.Linear(output_dim, output_dim)
+        ).float()
+
+    def forward(self, x):
+        return self.proj(x.float())
+
+
+class SFT_layer(nn.Module):
+    def __init__(self, channels_in, channels_out):
+        super(SFT_layer, self).__init__()
+        self.conv_gamma = nn.Sequential(
+            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        )
+        self.conv_beta = nn.Sequential(
+            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        )
+
+        self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)
+
+        '''
+        self.text_gamma = nn.Sequential(
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        ).float()
+        self.text_beta = nn.Sequential(
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+        ).float()
+        '''
+
+        self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation intermediate representation map: B * C * H * W
+        '''
+        # img_gamma = self.conv_gamma(inter)
+        # img_beta = self.conv_beta(inter)
+
+        B, C, H, W = inter.shape  # used for text/image fusion below
+
+        text_tokens = clip.tokenize(text_prompt).to(x.device)  # tokenize the text prompts (batch size)
+        with torch.no_grad():
+            text_embed = clip_model.encode_text(text_tokens)
+
+        text_proj = self.text_proj_head(text_embed).float()
+
+        # Expand the text embedding to (B, C, H, W)
+        # text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, self.conv_gamma[0].out_channels, H, W)
+        text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
+
+        # Fuse the intermediate image representation with the text embedding
+        combined = inter * text_proj_expanded
+        # combined = torch.cat([inter, text_proj_expanded], dim=1)
+
+        # Compute gamma and beta from the fused image/text features
+        img_gamma = self.conv_gamma(combined)
+        img_beta = self.conv_beta(combined)
+
+        ''' simple concat
+        text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # reshape to match (B, C, H, W)
+        text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))  # reshape to match (B, C, H, W)
+        '''
+
+        '''
+        text_proj = text_proj.unsqueeze(1).expand(-1, H*W, -1)  # B * (H*W) * C
+
+        # Reshape the intermediate image representation to B * (H*W) * C
+        inter_flat = inter.view(B, C, -1).permute(2, 0, 1)  # (H*W) * B * C
+
+        # Apply cross-attention
+        attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
+        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W
+
+        # Compute gamma and beta
+        img_gamma = self.conv_gamma(attn_output)
+        img_beta = self.conv_beta(attn_output)
+        '''
+        # (experiment: fusing text via concatenation instead of multiplication)
+        return x * img_gamma + img_beta
+
+
+class DGB(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size):
+        super(DGB, self).__init__()
+
+        # self.da_conv1 = DGM(n_feat, n_feat, kernel_size)
+        # self.da_conv2 = DGM(n_feat, n_feat, kernel_size)
+        self.dgm1 = DGM(n_feat, n_feat, kernel_size)
+        self.dgm2 = DGM(n_feat, n_feat, kernel_size)
+        self.conv1 = conv(n_feat, n_feat, kernel_size)
+        self.conv2 = conv(n_feat, n_feat, kernel_size)
+
+        self.relu = nn.LeakyReLU(0.1, True)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation representation: B * C * H * W
+        '''
+        out = self.relu(self.dgm1(x, inter, text_prompt))
+        out = self.relu(self.conv1(out))
+        out = self.relu(self.dgm2(out, inter, text_prompt))
+        out = self.conv2(out) + x
+
+        return out
+
+
+class DGG(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size, n_blocks):
+        super(DGG, self).__init__()
+        self.n_blocks = n_blocks
+        modules_body = [
+            DGB(conv, n_feat, kernel_size)
+            for _ in range(n_blocks)
+        ]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x, inter, text_prompt):
+        '''
+        :param x: feature map: B * C * H * W
+        :param inter: degradation representation: B * C * H * W
+        '''
+        res = x
+        for i in range(self.n_blocks):
+            res = self.body[i](res, inter, text_prompt)
+        res = self.body[-1](res)
+        res = res + x
+
+        return res
+
+
+class DGRN(nn.Module):
+    def __init__(self, opt, conv=default_conv):
+        super(DGRN, self).__init__()
+
+        self.n_groups = 5
+        n_blocks = 5
+        n_feats = 64
+        kernel_size = 3
+
+        # head module
+        modules_head = [conv(3, n_feats, kernel_size)]
+        self.head = nn.Sequential(*modules_head)
+
+        # body
+        modules_body = [
+            DGG(default_conv, n_feats, kernel_size, n_blocks)
+            for _ in range(self.n_groups)
+        ]
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+        # tail
+        modules_tail = [conv(n_feats, 3, kernel_size)]
+        self.tail = nn.Sequential(*modules_tail)
+
+    def forward(self, x, inter, text_prompt):
+        # head
+        x = self.head(x)
+
+        # body
+        res = x
+        for i in range(self.n_groups):
+            res = self.body[i](res, inter, text_prompt)
+        res = self.body[-1](res)
+        res = res + x
+
+        # tail
+        x = self.tail(res)
+
+        return x
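
At its core, SFT_layer applies feature-wise affine modulation: the projected text embedding is broadcast over the degradation map, and two 1x1 conv stacks predict per-pixel gamma and beta. A self-contained sketch of that computation on dummy tensors (randomly initialized convs as stand-ins; no CLIP or deformable conv required):

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 32, 32
x = torch.randn(B, C, H, W)        # feature map
inter = torch.randn(B, C, H, W)    # degradation map from the encoder
text_proj = torch.randn(B, C)      # stand-in for CLIP embedding + TextProjectionHead

# Broadcast the text vector over the spatial grid and gate the degradation map
combined = inter * text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)

conv_gamma = nn.Conv2d(C, C, 1, bias=False)   # untrained stand-ins for the 1x1 stacks
conv_beta = nn.Conv2d(C, C, 1, bias=False)
out = x * conv_gamma(combined) + conv_beta(combined)
print(out.shape)  # torch.Size([2, 64, 32, 32])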
text_net/__pycache__/DGRN.cpython-310.pyc
ADDED
Binary file (5.61 kB).

text_net/__pycache__/DGRN.cpython-38.pyc
ADDED
Binary file (4.53 kB).

text_net/__pycache__/deform_conv.cpython-310.pyc
ADDED
Binary file (2.2 kB).

text_net/__pycache__/deform_conv.cpython-36.pyc
ADDED
Binary file (2.14 kB).

text_net/__pycache__/deform_conv.cpython-38.pyc
ADDED
Binary file (2.21 kB).

text_net/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (2.33 kB).

text_net/__pycache__/encoder.cpython-36.pyc
ADDED
Binary file (2.38 kB).

text_net/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (2.36 kB).

text_net/__pycache__/moco.cpython-310.pyc
ADDED
Binary file (4.43 kB).

text_net/__pycache__/moco.cpython-36.pyc
ADDED
Binary file (4.39 kB).

text_net/__pycache__/moco.cpython-38.pyc
ADDED
Binary file (4.43 kB).

text_net/__pycache__/model.cpython-310.pyc
ADDED
Binary file (936 Bytes).

text_net/__pycache__/model.cpython-36.pyc
ADDED
Binary file (914 Bytes).

text_net/__pycache__/model.cpython-38.pyc
ADDED
Binary file (916 Bytes).
text_net/deform_conv.py
ADDED
@@ -0,0 +1,65 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair
+
+from mmcv.ops import modulated_deform_conv2d
+
+
+class DCN_layer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
+                 groups=1, deformable_groups=1, bias=True, extra_offset_mask=True):
+        super(DCN_layer, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        self.with_bias = bias
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+
+        self.extra_offset_mask = extra_offset_mask
+        self.conv_offset_mask = nn.Conv2d(
+            self.in_channels * 2,
+            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
+            bias=True
+        )
+
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.init_offset()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def init_offset(self):
+        self.conv_offset_mask.weight.data.zero_()
+        self.conv_offset_mask.bias.data.zero_()
+
+    def forward(self, input_feat, inter):
+        feat_degradation = torch.cat([input_feat, inter], dim=1)
+
+        out = self.conv_offset_mask(feat_degradation)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+
+        return modulated_deform_conv2d(input_feat.contiguous(), offset, mask, self.weight, self.bias, self.stride,
+                                       self.padding, self.dilation, self.groups, self.deformable_groups)
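
The offset/mask head concatenates the feature map with the degradation map and predicts 3·k·k channels per deformable group, which forward() splits into (x, y) offsets plus a sigmoid modulation mask. A shape check of just that head (CPU-only, no mmcv needed; assumes deformable_groups=1 as in the default above):

import torch
import torch.nn as nn

in_ch, k = 64, 3
conv_offset_mask = nn.Conv2d(in_ch * 2, 3 * k * k, kernel_size=k, padding=1)

feat = torch.randn(1, in_ch, 16, 16)
inter = torch.randn(1, in_ch, 16, 16)
out = conv_offset_mask(torch.cat([feat, inter], dim=1))  # 1 x 27 x 16 x 16

o1, o2, mask = torch.chunk(out, 3, dim=1)   # each 1 x 9 x 16 x 16
offset = torch.cat((o1, o2), dim=1)         # 1 x 18 x 16 x 16: (x, y) per kernel tap
mask = torch.sigmoid(mask)                  # modulation scalars in (0, 1)
print(offset.shape, mask.shape)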
text_net/encoder.py
ADDED
@@ -0,0 +1,67 @@
+from torch import nn
+from text_net.moco import MoCo
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_feat, out_feat, stride=1):
+        super(ResBlock, self).__init__()
+        self.backbone = nn.Sequential(
+            nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=stride, padding=1, bias=False),
+            nn.BatchNorm2d(out_feat),
+            nn.LeakyReLU(0.1, True),
+            nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_feat),
+        )
+        self.shortcut = nn.Sequential(
+            nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=stride, bias=False),
+            nn.BatchNorm2d(out_feat)
+        )
+
+    def forward(self, x):
+        return nn.LeakyReLU(0.1, True)(self.backbone(x) + self.shortcut(x))
+
+
+class ResEncoder(nn.Module):
+    def __init__(self):
+        super(ResEncoder, self).__init__()
+
+        self.E_pre = ResBlock(in_feat=3, out_feat=64, stride=1)
+        self.E = nn.Sequential(
+            ResBlock(in_feat=64, out_feat=128, stride=2),
+            ResBlock(in_feat=128, out_feat=256, stride=2),
+            nn.AdaptiveAvgPool2d(1)
+        )
+
+        self.mlp = nn.Sequential(
+            nn.Linear(256, 256),
+            nn.LeakyReLU(0.1, True),
+            nn.Linear(256, 256),
+        )
+
+    def forward(self, x):
+        inter = self.E_pre(x)
+        fea = self.E(inter).squeeze(-1).squeeze(-1)
+        out = self.mlp(fea)
+
+        return fea, out, inter
+
+
+class CBDE(nn.Module):
+    def __init__(self, opt):
+        super(CBDE, self).__init__()
+
+        dim = 256
+
+        # Encoder
+        self.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)
+
+    def forward(self, x_query, x_key):
+        if self.training:
+            # degradation-aware representation learning
+            fea, logits, labels, inter = self.E(x_query, x_key)
+
+            return fea, logits, labels, inter
+        else:
+            # degradation-aware representation learning
+            fea, inter = self.E(x_query, x_query)
+            return fea, inter
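
ResEncoder returns three things: the 256-d pooled feature, its MLP projection (used as the MoCo query/key), and the pre-downsampling 64-channel map that later conditions the restorer. A quick CPU shape check:

import torch
from text_net.encoder import ResEncoder

enc = ResEncoder().eval()   # eval() so BatchNorm uses running stats with batch size 1
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    fea, out, inter = enc(x)
print(fea.shape, out.shape, inter.shape)
# torch.Size([1, 256]) torch.Size([1, 256]) torch.Size([1, 64, 128, 128])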
text_net/moco.py
ADDED
@@ -0,0 +1,166 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import torch
+import torch.nn as nn
+
+
+class MoCo(nn.Module):
+    """
+    Build a MoCo model with: a query encoder, a key encoder, and a queue
+    https://arxiv.org/abs/1911.05722
+    """
+    def __init__(self, base_encoder, dim=256, K=3*256, m=0.999, T=0.07, mlp=False):
+        """
+        dim: feature dimension (default: 256)
+        K: queue size; number of negative keys (default: 3*256)
+        m: moco momentum of updating key encoder (default: 0.999)
+        T: softmax temperature (default: 0.07)
+        """
+        super(MoCo, self).__init__()
+
+        self.K = K
+        self.m = m
+        self.T = T
+
+        # create the encoders
+        # num_classes is the output fc dimension
+        self.encoder_q = base_encoder()
+        self.encoder_k = base_encoder()
+
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)  # initialize
+            param_k.requires_grad = False  # not update by gradient
+
+        # create the queue
+        self.register_buffer("queue", torch.randn(dim, K))
+        self.queue = nn.functional.normalize(self.queue, dim=0)
+
+        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+    @torch.no_grad()
+    def _momentum_update_key_encoder(self):
+        """
+        Momentum update of the key encoder
+        """
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, keys):
+        # gather keys before updating queue
+        # keys = concat_all_gather(keys)
+        batch_size = keys.shape[0]
+
+        ptr = int(self.queue_ptr)
+        assert self.K % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
+        ptr = (ptr + batch_size) % self.K  # move pointer
+
+        self.queue_ptr[0] = ptr
+
+    @torch.no_grad()
+    def _batch_shuffle_ddp(self, x):
+        """
+        Batch shuffle, for making use of BatchNorm.
+        *** Only support DistributedDataParallel (DDP) model. ***
+        """
+        # gather from all gpus
+        batch_size_this = x.shape[0]
+        x_gather = concat_all_gather(x)
+        batch_size_all = x_gather.shape[0]
+
+        num_gpus = batch_size_all // batch_size_this
+
+        # random shuffle index
+        idx_shuffle = torch.randperm(batch_size_all).cuda()
+
+        # broadcast to all gpus
+        torch.distributed.broadcast(idx_shuffle, src=0)
+
+        # index for restoring
+        idx_unshuffle = torch.argsort(idx_shuffle)
+
+        # shuffled index for this gpu
+        gpu_idx = torch.distributed.get_rank()
+        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+
+        return x_gather[idx_this], idx_unshuffle
+
+    @torch.no_grad()
+    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
+        """
+        Undo batch shuffle.
+        *** Only support DistributedDataParallel (DDP) model. ***
+        """
+        # gather from all gpus
+        batch_size_this = x.shape[0]
+        x_gather = concat_all_gather(x)
+        batch_size_all = x_gather.shape[0]
+
+        num_gpus = batch_size_all // batch_size_this
+
+        # restored index for this gpu
+        gpu_idx = torch.distributed.get_rank()
+        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+
+        return x_gather[idx_this]
+
+    def forward(self, im_q, im_k):
+        """
+        Input:
+            im_q: a batch of query images
+            im_k: a batch of key images
+        Output:
+            logits, targets
+        """
+        if self.training:
+            # compute query features
+            embedding, q, inter = self.encoder_q(im_q)  # queries: NxC
+            q = nn.functional.normalize(q, dim=1)
+
+            # compute key features
+            with torch.no_grad():  # no gradient to keys
+                self._momentum_update_key_encoder()  # update the key encoder
+
+                _, k, _ = self.encoder_k(im_k)  # keys: NxC
+                k = nn.functional.normalize(k, dim=1)
+
+            # compute logits
+            # Einstein sum is more intuitive
+            # positive logits: Nx1
+            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+            # negative logits: NxK
+            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+
+            # logits: Nx(1+K)
+            logits = torch.cat([l_pos, l_neg], dim=1)
+
+            # apply temperature
+            logits /= self.T
+
+            # labels: positive key indicators
+            labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
+            # dequeue and enqueue
+            self._dequeue_and_enqueue(k)
+
+            return embedding, logits, labels, inter
+        else:
+            embedding, _, inter = self.encoder_q(im_q)
+
+            return embedding, inter
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+                      for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
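
The contrastive head computes one positive logit per query against its momentum key and K negative logits against the queue; the labels are all zeros because the positive always sits in column 0. A standalone sketch of that logit construction (CPU, random tensors; 1792 matches the queue size CBDE sets, batch_size 7 times dim 256):

import torch

N, C, K = 4, 256, 1792
q = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
k = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
queue = torch.nn.functional.normalize(torch.randn(C, K), dim=0)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)   # N x 1
l_neg = torch.einsum('nc,ck->nk', [q, queue])            # N x K
logits = torch.cat([l_pos, l_neg], dim=1) / 0.07         # temperature T
labels = torch.zeros(N, dtype=torch.long)                # positive key is index 0
loss = torch.nn.functional.cross_entropy(logits, labels)
print(logits.shape, loss.item())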
text_net/model.py
ADDED
@@ -0,0 +1,29 @@
+from torch import nn
+
+from text_net.encoder import CBDE
+from text_net.DGRN import DGRN
+
+
+class AirNet(nn.Module):
+    def __init__(self, opt):
+        super(AirNet, self).__init__()
+
+        # Restorer
+        self.R = DGRN(opt)
+
+        # Encoder
+        self.E = CBDE(opt)
+
+    def forward(self, x_query, x_key, text_prompt):
+        if self.training:
+            fea, logits, labels, inter = self.E(x_query, x_key)
+
+            restored = self.R(x_query, inter, text_prompt)
+
+            return restored, logits, labels
+        else:
+            fea, inter = self.E(x_query, x_query)
+
+            restored = self.R(x_query, inter, text_prompt)
+
+            return restored
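
End to end, AirNet routes the query image through CBDE to obtain the degradation map, then through DGRN conditioned on the text prompt. A minimal eval-mode sketch, assuming a CUDA device plus the clip and mmcv dependencies these modules import at load time, and a random (untrained) network:

import argparse
import torch
from text_net.model import AirNet

opt = argparse.Namespace(batch_size=7)   # CBDE only reads batch_size (queue size = 7 * 256)
net = AirNet(opt).cuda().eval()

x = torch.randn(1, 3, 128, 128).cuda()
with torch.no_grad():
    restored = net(x_query=x, x_key=x, text_prompt=["remove the haze"])
print(restored.shape)  # torch.Size([1, 3, 128, 128]): DGRN preserves resolution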
utils/.DS_Store
ADDED
Binary file (6.15 kB).

utils/__init__.py
ADDED
File without changes (empty file).

utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (138 Bytes).

utils/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (123 Bytes).

utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (148 Bytes).

utils/__pycache__/dataset_utils.cpython-310.pyc
ADDED
Binary file (11.2 kB).

utils/__pycache__/dataset_utils.cpython-36.pyc
ADDED
Binary file (30.8 kB).

utils/__pycache__/dataset_utils.cpython-38.pyc
ADDED
Binary file (10.6 kB).

utils/__pycache__/dataset_utils_CDD.cpython-310.pyc
ADDED
Binary file (1.58 kB).

utils/__pycache__/degradation_utils.cpython-310.pyc
ADDED
Binary file (1.8 kB).

utils/__pycache__/degradation_utils.cpython-36.pyc
ADDED
Binary file (3.4 kB).

utils/__pycache__/degradation_utils.cpython-38.pyc
ADDED
Binary file (1.79 kB).

utils/__pycache__/image_io.cpython-310.pyc
ADDED
Binary file (11 kB).

utils/__pycache__/image_io.cpython-36.pyc
ADDED
Binary file (11.3 kB).

utils/__pycache__/image_io.cpython-38.pyc
ADDED
Binary file (11.1 kB).

utils/__pycache__/image_utils.cpython-310.pyc
ADDED
Binary file (7.48 kB).

utils/__pycache__/image_utils.cpython-36.pyc
ADDED
Binary file (7.61 kB).

utils/__pycache__/image_utils.cpython-38.pyc
ADDED
Binary file (7.46 kB).

utils/__pycache__/imresize.cpython-36.pyc
ADDED
Binary file (4.75 kB).

utils/__pycache__/imresize.cpython-38.pyc
ADDED
Binary file (4.75 kB).

utils/__pycache__/loss_utils.cpython-38.pyc
ADDED
Binary file (1.43 kB).

utils/__pycache__/val_utils.cpython-310.pyc
ADDED
Binary file (3.35 kB).

utils/__pycache__/val_utils.cpython-36.pyc
ADDED
Binary file (2.34 kB).

utils/__pycache__/val_utils.cpython-38.pyc
ADDED
Binary file (3.27 kB).
utils/dataset_utils.py
ADDED
@@ -0,0 +1,309 @@
+import os
+import random
+import copy
+from PIL import Image
+import numpy as np
+
+from torch.utils.data import Dataset
+from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
+
+from utils.image_utils import random_augmentation, crop_img
+from utils.degradation_utils import Degradation
+
+
+class TrainDataset(Dataset):
+    def __init__(self, args):
+        super(TrainDataset, self).__init__()
+        self.args = args
+        self.rs_ids = []
+        self.hazy_ids = []
+        self.D = Degradation(args)
+        self.de_temp = 0
+        self.de_type = self.args.de_type
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+
+        self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}
+
+        self._init_ids()
+
+        self.crop_transform = Compose([
+            ToPILImage(),
+            RandomCrop(args.patch_size),
+        ])
+
+        self.toTensor = ToTensor()
+
+    def _init_ids(self):
+        if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
+            self._init_clean_ids()
+        if 'derain' in self.de_type:
+            self._init_rs_ids()
+        if 'dehaze' in self.de_type:
+            self._init_hazy_ids()
+
+        random.shuffle(self.de_type)
+
+    def _init_clean_ids(self):
+        clean_ids = []
+        # Keep only image files from the directory listing
+        name_list = os.listdir(self.args.denoise_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+
+        clean_ids += [self.args.denoise_dir + id_ for id_ in name_list]
+
+        if 'denoise_15' in self.de_type:
+            self.s15_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s15_ids)
+            self.s15_counter = 0
+        if 'denoise_25' in self.de_type:
+            self.s25_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s25_ids)
+            self.s25_counter = 0
+        if 'denoise_50' in self.de_type:
+            self.s50_ids = copy.deepcopy(clean_ids)
+            random.shuffle(self.s50_ids)
+            self.s50_counter = 0
+
+        # print(clean_ids)
+
+        self.num_clean = len(clean_ids)
+
+    def _init_hazy_ids(self):
+        # Keep only image files from the directory listing
+        dehaze_ids = []
+        name_list = os.listdir(self.args.dehaze_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        dehaze_ids += [self.args.dehaze_dir + id_ for id_ in name_list]
+        self.hazy_ids = dehaze_ids
+
+        self.hazy_counter = 0
+        self.num_hazy = len(self.hazy_ids)
+
+    def _init_rs_ids(self):
+        # Keep only image files from the directory listing
+        derain_ids = []
+        name_list = os.listdir(self.args.derain_dir)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        derain_ids += [self.args.derain_dir + id_ for id_ in name_list]
+        self.rs_ids = derain_ids
+
+        self.rl_counter = 0
+        # print(derain_ids)
+
+        self.num_rl = len(self.rs_ids)
+
+    def _crop_patch(self, img_1, img_2):
+        H = img_1.shape[0]
+        W = img_1.shape[1]
+        ind_H = random.randint(0, H - self.args.patch_size)
+        ind_W = random.randint(0, W - self.args.patch_size)
+
+        patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+        patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+
+        return patch_1, patch_2
+
+    def _get_gt_name(self, rainy_name):
+        gt_name = 'data/' + 'Target/Derain/norain-' + rainy_name.split('rain-')[-1]
+        return gt_name
+
+    def _get_nonhazy_name(self, hazy_name):
+        # Fixed: the original referenced an undefined rainy_name here
+        gt_name = 'data/' + 'Target/Dehaze/nohaze-' + hazy_name.split('haze-')[-1]
+        return gt_name
+
+    def __getitem__(self, _):
+        de_id = self.de_dict[self.de_type[self.de_temp]]
+
+        if de_id < 3:
+            if de_id == 0:
+                clean_id = self.s15_ids[self.s15_counter]
+                self.s15_counter = (self.s15_counter + 1) % self.num_clean
+                if self.s15_counter == 0:
+                    random.shuffle(self.s15_ids)
+            elif de_id == 1:
+                clean_id = self.s25_ids[self.s25_counter]
+                self.s25_counter = (self.s25_counter + 1) % self.num_clean
+                if self.s25_counter == 0:
+                    random.shuffle(self.s25_ids)
+            elif de_id == 2:
+                clean_id = self.s50_ids[self.s50_counter]
+                self.s50_counter = (self.s50_counter + 1) % self.num_clean
+                if self.s50_counter == 0:
+                    random.shuffle(self.s50_ids)
+
+            # clean_id = random.randint(0, len(self.clean_ids) - 1)
+            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
+            clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
+            clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)
+
+            # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
+            clean_name = clean_id.split("/")[-1].split('.')[0]
+
+            clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)
+            degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)
+        else:
+            if de_id == 3:
+                # Rain Streak Removal
+                # rl_id = random.randint(0, len(self.rl_ids) - 1)
+                degrad_img = crop_img(np.array(Image.open(self.rs_ids[self.rl_counter]).convert('RGB')), base=16)
+                clean_name = self._get_gt_name(self.rs_ids[self.rl_counter])
+                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+
+                self.rl_counter = (self.rl_counter + 1) % self.num_rl
+                if self.rl_counter == 0:
+                    random.shuffle(self.rs_ids)
+            elif de_id == 4:
+                # Dehazing with SOTS outdoor training set
+                # hazy_id = random.randint(0, len(self.hazy_ids) - 1)
+                degrad_img = crop_img(np.array(Image.open(self.hazy_ids[self.hazy_counter]).convert('RGB')), base=16)
+                clean_name = self._get_nonhazy_name(self.hazy_ids[self.hazy_counter])
+                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+
+                self.hazy_counter = (self.hazy_counter + 1) % self.num_hazy
+                if self.hazy_counter == 0:
+                    random.shuffle(self.hazy_ids)
+            degrad_patch_1, clean_patch_1 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+            degrad_patch_2, clean_patch_2 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+
+        clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)
+        degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)
+
+        self.de_temp = (self.de_temp + 1) % len(self.de_type)
+        if self.de_temp == 0:
+            random.shuffle(self.de_type)
+
+        return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2
+
+    def __len__(self):
+        return 400 * len(self.args.de_type)
+
+
+class DenoiseTestDataset(Dataset):
+    def __init__(self, args):
+        super(DenoiseTestDataset, self).__init__()
+        self.args = args
+        self.clean_ids = []
+        self.sigma = 15
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+
+        self._init_clean_ids()
+
+        self.toTensor = ToTensor()
+
+    def _init_clean_ids(self):
+        clean_ids = []
+        # Keep only image files from the directory listing
+        name_list = os.listdir(self.args.denoise_path)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]
+
+        self.num_clean = len(self.clean_ids)
+
+    def _add_gaussian_noise(self, clean_patch):
+        noise = np.random.randn(*clean_patch.shape)
+        noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
+        return noisy_patch, clean_patch
+
+    def set_sigma(self, sigma):
+        self.sigma = sigma
+
+    def __getitem__(self, clean_id):
+        clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
+        clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
+
+        noisy_img, _ = self._add_gaussian_noise(clean_img)
+        clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)
+
+        return [clean_name], noisy_img, clean_img
+
+    def __len__(self):
+        return self.num_clean
+
+
+class DerainDehazeDataset(Dataset):
+    def __init__(self, args, task="derain"):
+        super(DerainDehazeDataset, self).__init__()
+        self.ids = []
+        self.task_idx = 0
+        self.args = args
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+
+        self.task_dict = {'derain': 0, 'dehaze': 1}
+        self.toTensor = ToTensor()
+
+        self.set_dataset(task)
+
+    def _init_input_ids(self):
+        if self.task_idx == 0:
+            self.ids = []
+            # Keep only image files from the directory listing
+            name_list = os.listdir(self.args.derain_path + 'input/')
+            name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+            self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
+        elif self.task_idx == 1:
+            self.ids = []
+            # Keep only image files from the directory listing
+            name_list = os.listdir(self.args.dehaze_path + 'input/')
+            name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+            self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]
+
+        self.length = len(self.ids)
+
+    def _get_gt_path(self, degraded_name):
+        if self.task_idx == 0:
+            gt_name = '/'.join(degraded_name.replace("input", "target").split('/')[:-1] + degraded_name.replace("input", "target").replace("rain", "norain").split('/')[-1:])
+            print(gt_name)
+        elif self.task_idx == 1:
+            dir_name = degraded_name.split("input")[0] + 'target/'
+            name = degraded_name.split('/')[-1].split('_')[0] + '.png'
+            gt_name = dir_name + name
+        return gt_name
+
+    def set_dataset(self, task):
+        self.task_idx = self.task_dict[task]
+        self._init_input_ids()
+
+    def __getitem__(self, idx):
+        degraded_path = self.ids[idx]
+        clean_path = self._get_gt_path(degraded_path)
+
+        degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
+        clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)
+
+        clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
+        degraded_name = degraded_path.split('/')[-1][:-4]
+
+        return [degraded_name], degraded_img, clean_img
+
+    def __len__(self):
+        return self.length
+
+
+class TestSpecificDataset(Dataset):
+    def __init__(self, args):
+        super(TestSpecificDataset, self).__init__()
+        self.args = args
+        self.degraded_ids = []
+        # Fixed: required by _init_clean_ids but never set in the original __init__
+        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+        self._init_clean_ids(args.test_path)
+
+        self.toTensor = ToTensor()
+
+    def _init_clean_ids(self, root):
+        degraded_ids = []
+        # Keep only image files from the directory listing
+        name_list = os.listdir(root)
+        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+        self.degraded_ids += [root + id_ for id_ in name_list]
+
+        self.num_img = len(self.degraded_ids)
+
+    def __getitem__(self, idx):
+        degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
+        name = self.degraded_ids[idx].split('/')[-1][:-4]
+
+        degraded_img = self.toTensor(degraded_img)
+
+        return [name], degraded_img
+
+    def __len__(self):
+        return self.num_img
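
TrainDataset is driven entirely by the args namespace: de_type selects the degradation mix, and the three *_dir paths must end with a slash because file IDs are built by plain string concatenation. A hypothetical config sketch (the directory paths are placeholders, and Degradation(args) may read additional fields this diff does not show):

import argparse
from torch.utils.data import DataLoader
from utils.dataset_utils import TrainDataset

args = argparse.Namespace(
    de_type=['denoise_15', 'derain', 'dehaze'],
    denoise_dir='data/Train/Denoise/',   # trailing slash required: paths are concatenated
    derain_dir='data/Train/Derain/',
    dehaze_dir='data/Train/Dehaze/',
    patch_size=128,
)
trainset = TrainDataset(args)
loader = DataLoader(trainset, batch_size=8, shuffle=True)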