"""Gradio demo that translates a SAR image into an optical image.

A conditional diffusion UNet (``SAR2OptUNetv3``) is sampled with an
``LCMScheduler``; the SAR image is concatenated to the noisy sample as an
extra input channel at every denoising step.
"""

import argparse
import gc
import os

import diffusers
import gradio as gr
import numpy as np
import pandas as pd
import torch as th
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, LCMScheduler
from PIL import Image
from safetensors import safe_open
from torchvision import transforms

from models import SAR2OptUNetv3
from utils import update_args_from_yaml, safe_load

# Preprocessing applied to the single-channel SAR condition image: convert to
# a tensor, resize to 256x256, then normalize with mean 0.5 and std 0.5.
transform_sar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize((0.5,), (0.5,)),
])

# Display name shown in the UI -> checkpoint path (relative to the working
# directory the app is started from).
AVAILABLE_MODELS = {
    "Sen12:LCM-Model": "models/model.safetensors",
    "Sen12:Org-Model": "models/model_org.safetensors",
}

device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')


def safe_load(model_path):
    """Read every tensor of a safetensors checkpoint into a state dict.

    NOTE: this definition shadows the ``safe_load`` imported from ``utils``.

    Args:
        model_path: Path of the checkpoint file; must contain "safetensors".

    Returns:
        A dict mapping each tensor name in the file to its tensor, loaded on
        the CPU.

    Raises:
        ValueError: If ``model_path`` does not contain "safetensors".
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if "safetensors" not in model_path:
        raise ValueError(
            f"Expected a safetensors checkpoint, got: {model_path!r}")
    with safe_open(model_path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


# in_channels=4: the 3-channel noisy sample concatenated with the 1-channel
# SAR condition (see ``predict``); out_channels=3 is the model prediction.
unet_model = SAR2OptUNetv3(
    sample_size=256,
    in_channels=4,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
# NOTE(review): only the architecture has been instantiated at this point; no
# checkpoint is read until the first call to ``predict``, despite the message.
print('load unet safetensos done!')

lcm_scheduler = LCMScheduler(num_train_timesteps=1000)
unet_model.to(device)
unet_model.eval()
model_kwargs = {}

# Key of AVAILABLE_MODELS whose weights are currently loaded into
# ``unet_model``; ``None`` until the first request has been served.
_loaded_model_name = None


def _ensure_model_loaded(model_name):
    """Load the checkpoint for ``model_name`` unless it is already active."""
    global _loaded_model_name
    if model_name != _loaded_model_name:
        checkpoint_path = AVAILABLE_MODELS[model_name]
        unet_model.load_state_dict(safe_load(checkpoint_path), strict=True)
        # Recorded only after a successful load so a failed load is retried.
        _loaded_model_name = model_name
    unet_model.eval().to(device)


def predict(condition, nums_step, model_name):
    """Generate an optical image from a SAR image.

    Args:
        condition: PIL image holding the SAR input; it is converted to
            single-channel ("L") mode before preprocessing.
        nums_step: Number of sampling steps passed to the scheduler. Slider
            values may arrive as floats, so the value is cast to ``int``.
        model_name: Key of ``AVAILABLE_MODELS`` selecting the checkpoint.

    Returns:
        A PIL image built from the scheduler's final denoised prediction.

    Raises:
        gr.Error: If no input image was provided.
        KeyError: If ``model_name`` is not a key of ``AVAILABLE_MODELS``.
    """
    if condition is None:
        raise gr.Error("Please provide a SAR image.")

    # Weights are reloaded from disk only when the selected model changes.
    _ensure_model_loaded(model_name)
    num_steps = int(nums_step)

    with th.no_grad():
        # NOTE(review): LCMScheduler may reject step counts larger than its
        # `original_inference_steps` setting -- confirm that the 500/1000-step
        # examples work with the installed diffusers version.
        lcm_scheduler.set_timesteps(num_steps, device=device)
        timesteps = lcm_scheduler.timesteps

        # Sampling starts from pure Gaussian noise.
        pred_latent = th.randn(size=[1, 3, 256, 256], device=device)

        condition = transform_sar(condition.convert("L"))
        condition = th.unsqueeze(condition, 0).to(device)

        for timestep in timesteps:
            # The SAR image is appended as a 4th channel at every step.
            latent_to_pred = th.cat((pred_latent, condition), dim=1)
            model_pred = unet_model(latent_to_pred, timestep)
            pred_latent, denoised = lcm_scheduler.step(
                model_output=model_pred,
                timestep=timestep,
                sample=pred_latent,
                return_dict=False)

    # Map the last denoised prediction from [-1, 1] to uint8 [0, 255] and
    # reorder from (1, C, H, W) to (H, W, C) for PIL.
    sample = denoised.cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
    sample = sample.permute(0, 2, 3, 1).contiguous().numpy().squeeze(0)
    return Image.fromarray(sample)


# Example images are expected to sit next to this script.
_EXAMPLE_DIR = os.path.dirname(__file__)
_MODEL_CHOICES = list(AVAILABLE_MODELS)

# The Markdown below is kept flush-left on purpose: lines indented by four or
# more spaces would be rendered as code blocks.
_DESCRIPTION = """
# 🎯 Instruction
This is a project that converts SAR images into optical images, based on conditional diffusion.

Input a SAR image, and its corresponding optical image will be obtained.

## 📢 Inputs
- `condition`: the SAR image that you want to transfer.
- `timestep_respacing`: the number of iteration steps when inference.

## 🎉 Outputs
- The corresponding optical image.

**Paper** : [Accelerating Diffusion for SAR-to-Optical Image Translation via Adversarial Consistency Distillation](https://arxiv.org/abs/2407.06095)

**Github** : https://github.com/Coordi777/Accelerating-Diffusion-for-SAR-to-Optical-Image-Translation
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        # step=1 keeps the step count integral.
        gr.Slider(1, 1000, step=1),
        gr.Dropdown(
            choices=_MODEL_CHOICES,
            value=_MODEL_CHOICES[0],
            label="Choose the Model"),
    ],
    outputs=gr.Image(type="pil"),
    examples=[
        [os.path.join(_EXAMPLE_DIR, "sar_1.png"), 8, "Sen12:LCM-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_2.png"), 16, "Sen12:LCM-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_3.png"), 500, "Sen12:Org-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_4.png"), 1000, "Sen12:Org-Model"],
    ],
    title="SAR to Optical Image🚀",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch()