|
import gradio as gr |
|
import argparse |
|
import os |
|
|
|
import pandas as pd |
|
from PIL import Image |
|
import numpy as np |
|
import torch as th |
|
from torchvision import transforms |
|
|
|
import diffusers |
|
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, LCMScheduler |
|
import gc |
|
from safetensors import safe_open |
|
|
|
from models import SAR2OptUNetv3 |
|
from utils import update_args_from_yaml, safe_load |
|
|
|
# Preprocessing for the SAR condition image: PIL image -> float tensor,
# resized to 256x256, then normalised with mean 0.5 / std 0.5,
# i.e. x -> (x - 0.5) / 0.5.
transform_sar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    # ``(0.5)`` is only the float 0.5; use 1-tuples so the arguments match the
    # documented per-channel sequence type (the SAR input is converted to a
    # single "L" channel before this transform is applied).
    transforms.Normalize((0.5,), (0.5,)),
])
|
# Selectable checkpoints (dropdown label -> safetensors path). Paths are
# resolved relative to this file, the same way the example images are, so the
# app also works when launched from a different working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

AVAILABLE_MODELS = {
    "Sen12:LCM-Model": os.path.join(_APP_DIR, "models", "model.safetensors"),
    "Sen12:Org-Model": os.path.join(_APP_DIR, "models", "model_org.safetensors"),
}

# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
|
|
|
def safe_load(model_path):
    """Read every tensor from a safetensors file into a CPU state dict.

    NOTE: this definition shadows the ``safe_load`` imported from ``utils`` at
    the top of the file; calls made after this point resolve to this version.

    Args:
        model_path: path (``str`` or ``os.PathLike``) to a safetensors file.

    Returns:
        dict mapping each tensor name stored in the file to its tensor,
        loaded on the CPU.

    Raises:
        ValueError: if the path does not contain "safetensors". This is an
            explicit check rather than ``assert`` so the validation still runs
            under ``python -O``.
    """
    path = os.fspath(model_path)
    if "safetensors" not in path:
        raise ValueError(f"Expected a .safetensors checkpoint, got {path!r}")
    with safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}
|
|
|
# Conditional UNet built once at import time. Its input is the 3-channel noisy
# sample concatenated with the 1-channel SAR condition (see ``predict``), and
# it outputs 3 channels. Checkpoint weights are loaded later, in ``predict``.
unet_model = SAR2OptUNetv3(
    sample_size=256,
    in_channels=4,  # 3 sample channels + 1 SAR condition channel
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    # Attention appears only in the 5th down block and the 2nd up block;
    # every other block is a plain (non-attention) block.
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)
|
|
|
# Scheduler shared by all requests; ``predict`` calls ``set_timesteps`` on it
# before every run.
lcm_scheduler = LCMScheduler(num_train_timesteps=1000)

unet_model.to(device)
unet_model.eval()
# No checkpoint has been read at this point (weights are loaded from
# AVAILABLE_MODELS inside predict), so only report that the network was built.
print('build unet done! (checkpoint weights are loaded in predict)')

# NOTE(review): not referenced elsewhere in this file; kept in case other code
# imports it.
model_kwargs = {}
|
|
|
|
|
# Name of the AVAILABLE_MODELS entry whose weights are currently in
# ``unet_model`` (None when nothing is loaded or the last load failed). Lets
# ``predict`` skip re-reading the checkpoint from disk on every request.
_loaded_model_name = None


def _ensure_checkpoint_loaded(model_name):
    """Load the weights for ``model_name`` into ``unet_model`` unless already loaded."""
    global _loaded_model_name
    if model_name == _loaded_model_name:
        return
    unet_checkpoint = AVAILABLE_MODELS[model_name]
    # Invalidate first: a strict load that raises may already have copied some
    # tensors, so the network must no longer be treated as holding old weights.
    _loaded_model_name = None
    unet_model.load_state_dict(safe_load(unet_checkpoint), strict=True)
    unet_model.eval().to(device)
    _loaded_model_name = model_name


def _preprocess_condition(condition):
    """Turn a PIL image into a batched single-channel tensor on ``device``.

    The image is converted to "L" mode, passed through ``transform_sar``
    (resize to 256x256 + normalisation) and given a leading batch dimension,
    giving a (1, 1, 256, 256) tensor.
    """
    tensor = transform_sar(condition.convert("L"))
    return th.unsqueeze(tensor, 0).to(device)


def _to_pil(sample):
    """Convert a (1, 3, H, W) tensor into an RGB PIL image.

    Values are mapped with ``(x + 1) * 127.5`` and clamped to [0, 255], so the
    range [-1, 1] spans the full 8-bit range; anything outside is clipped.
    """
    pixels = ((sample.cpu() + 1) * 127.5).clamp(0, 255).to(th.uint8)
    # Drop the batch axis and go from (C, H, W) to the (H, W, C) layout PIL expects.
    array = pixels[0].permute(1, 2, 0).contiguous().numpy()
    return Image.fromarray(array)


def predict(condition, nums_step, model_name):
    """Translate a SAR image into an optical image with the selected model.

    Args:
        condition: PIL image from the Gradio ``Image`` input; ``None`` when the
            user submits without uploading an image.
        nums_step: number of LCM scheduler steps. A slider may deliver a
            float, so the value is rounded to an ``int`` before use.
        model_name: key of ``AVAILABLE_MODELS`` naming the checkpoint to use.

    Returns:
        The generated optical image as a PIL ``Image``.

    Raises:
        gr.Error: if no image was provided or fewer than one step was
            requested (Gradio shows the message to the user).
    """
    if condition is None:
        raise gr.Error("Please upload a SAR image.")
    num_steps = int(round(nums_step))
    if num_steps < 1:
        raise gr.Error("The number of steps must be at least 1.")

    _ensure_checkpoint_loaded(model_name)

    with th.no_grad():
        lcm_scheduler.set_timesteps(num_steps, device=device)
        # Start from standard Gaussian noise for the 3-channel output image.
        pred_latent = th.randn(size=[1, 3, 256, 256], device=device)
        cond = _preprocess_condition(condition)
        for timestep in lcm_scheduler.timesteps:
            # Channel-wise concat: 3 sample channels + 1 condition channel,
            # matching the UNet's in_channels=4.
            model_input = th.cat((pred_latent, cond), dim=1)
            model_pred = unet_model(model_input, timestep)
            # With return_dict=False the LCM step returns a
            # (prev_sample, denoised) tuple.
            pred_latent, denoised = lcm_scheduler.step(
                model_output=model_pred,
                timestep=timestep,
                sample=pred_latent,
                return_dict=False,
            )

    return _to_pil(denoised)
|
|
|
|
|
# Gradio UI. Unlabelled inputs take their labels from predict's parameter
# names ("condition", "nums_step"), which the description below refers to.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        # step=1 keeps the step count integral.
        gr.Slider(1, 1000, step=1),
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=list(AVAILABLE_MODELS.keys())[0],
            label="Choose the Model",
        ),
    ],
    outputs=gr.Image(type="pil"),
    examples=[
        [os.path.join(os.path.dirname(__file__), "sar_1.png"), 8, "Sen12:LCM-Model"],
        [os.path.join(os.path.dirname(__file__), "sar_2.png"), 16, "Sen12:LCM-Model"],
        [os.path.join(os.path.dirname(__file__), "sar_3.png"), 500, "Sen12:Org-Model"],
        [os.path.join(os.path.dirname(__file__), "sar_4.png"), 1000, "Sen12:Org-Model"],
    ],
    title="SAR to Optical Image🚀",
    description="""
# 🎯 Instruction
This project converts SAR images into optical images using conditional diffusion.

Input a SAR image to obtain its corresponding optical image.

## 📢 Inputs
- `condition`: the SAR image that you want to translate.
- `nums_step`: the number of denoising steps used at inference.
- `Choose the Model`: the checkpoint to use. The bundled examples use 8–16 steps with `Sen12:LCM-Model` and 500–1000 steps with `Sen12:Org-Model`.

## 🎉 Outputs
- The corresponding optical image.

**Paper** : [Accelerating Diffusion for SAR-to-Optical Image Translation via Adversarial Consistency Distillation](https://arxiv.org/abs/2407.06095)

**Github** : https://github.com/Coordi777/Accelerating-Diffusion-for-SAR-to-Optical-Image-Translation
""",
)


if __name__ == "__main__":
    demo.launch()
|
|