"""Gradio demo for ReVP: reversible censorship of images with a LoRA-adapted VAE.

Tab "Add censorship" censors a painted region and hides a bounded residual in the output;
tab "Remove censorship" runs the adapter-enabled VAE encoder on such an image and decodes
it with the plain decoder to reconstruct the content.
"""
from argparse import Namespace
from glob import glob
import os

import gradio as gr
import safetensors
import torch
import torchvision
import yaml
from diffusers import AutoencoderKL
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

# Download (or reuse the cached copy of) the released checkpoint and its hyper-parameters.
pretrained_model_path = snapshot_download(repo_id="revp2024/revp-censorship")
# NOTE(review): the pattern contains no `**`, so `recursive=True` has no effect and only a
# top-level hparams.yml is matched; `[0]` raises IndexError if it is missing.
with open(glob(os.path.join(pretrained_model_path, 'hparams.yml'), recursive=True)[0],
          encoding='utf-8') as f:
    # Fields read below: rank, pretrained_model_name_or_path, budget.
    args = Namespace(**yaml.safe_load(f))


def _select_device():
    """Return the preferred torch device name: 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


def prepare_model():
    """Build the AutoencoderKL, attach the ReVP LoRA adapter and load its trained weights.

    Returns:
        The PEFT-wrapped VAE, moved to the device chosen by ``_select_device``.
    """
    print('Loading model ...')
    vae_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["conv", "conv1", "conv2", "to_q", "to_k", "to_v", "to_out.0"],
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    vae = get_peft_model(vae, vae_lora_config)

    lora_weights_path = os.path.join(pretrained_model_path, "pytorch_lora_weights.safetensors")
    # `safe_open` is exported at the package top level; no need to reach into `safetensors.torch`.
    with safetensors.safe_open(lora_weights_path, framework="pt", device="cpu") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    set_peft_model_state_dict(vae, state_dict)

    print('Done.')
    return vae.to(_select_device())


def _read_editor_value(input_image):
    """Convert a ``gr.ImageEditor`` value into ``(images, mask)`` float tensors in [0, 1].

    ``images`` is ``(1, 3, H, W)``: the first three channels of the background.
    ``mask`` is ``(1, 1, H, W)``: the last (alpha) channel of the first layer, or all
    zeros when no layer is present. Assumes HWC uint8 numpy arrays (the ImageEditor's
    default numpy output) — TODO confirm if the component type changes.

    Raises:
        gr.Error: if no background image was provided.
    """
    if input_image is None or input_image.get('background') is None:
        raise gr.Error('Please upload an image first.')
    background = input_image['background']
    layers = input_image['layers']

    images = torch.from_numpy(background).permute(2, 0, 1)[None, :3] / 255
    if layers:
        mask = torch.from_numpy(layers[0]).permute(2, 0, 1)[None, -1:] / 255
    else:
        mask = torch.zeros_like(images[:, :1])
    return images, mask


def _to_pixel_range(images):
    """Undo the ``x * 2 - 1`` normalization and scale to [0, 255] (no clamping)."""
    return (images / 2 + 0.5) * 255


@torch.no_grad()
def add_censorship(input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges,
                   soft_edge_kernel_size):
    """Censor the painted region and embed a bounded residual in the result.

    The output equals the censored image plus ``clamp(recon - censored, ±args.budget)``,
    where ``recon`` is the adapter-disabled VAE reconstruction of the original image
    (all in 0-255 units). Presumably this residual is what the adapter-enabled encoder
    in ``remove_censorship`` reads back.

    Args:
        input_image: ``gr.ImageEditor`` value; the first layer's alpha marks the region.
        mode: 'Pixelation', 'Gaussian blur' or 'Black'.
        pixelation_block_size: average-pooling kernel size for 'Pixelation'.
        blur_kernel_size: odd Gaussian kernel size for 'Gaussian blur'.
        soft_edges: whether to Gaussian-blur (feather) the mask before blending.
        soft_edge_kernel_size: odd Gaussian kernel size used to feather the mask.

    Returns:
        The censored image as an ``(H', W', 3)`` uint8 numpy array, where H'/W' are the
        (possibly downscaled) sides rounded down to multiples of 8.

    Raises:
        ValueError: for an unknown ``mode``.
        gr.Error: if no image was provided.
    """
    input_images, mask = _read_editor_value(input_image)

    # Cap the longer side at 1024 px, keeping the aspect ratio.
    H, W = input_images.shape[-2:]
    if H > 1024 or W > 1024:
        if H > W:
            H, W = 1024, int(1024 * W / H)
        else:
            H, W = int(1024 * H / W), 1024
    # Round both sides down to multiples of 8 (presumably for the VAE's 8x downsampling);
    # never go below 8 so tiny inputs don't collapse to an empty tensor.
    H_q8 = max(8, (H // 8) * 8)
    W_q8 = max(8, (W // 8) * 8)
    input_images = torch.nn.functional.interpolate(input_images, (H_q8, W_q8), mode='bilinear')
    mask = torch.nn.functional.interpolate(mask, (H_q8, W_q8))
    if soft_edges:
        # Feather the mask so the censored region blends into its surroundings.
        mask = torchvision.transforms.functional.gaussian_blur(mask, int(soft_edge_kernel_size))

    device = vae.device
    input_images = input_images.to(device)
    mask = mask.to(device)

    # Gradio may deliver slider values as floats; torch/torchvision require int sizes.
    if mode == 'Pixelation':
        censored = torch.nn.functional.avg_pool2d(input_images, int(pixelation_block_size))
        censored = torch.nn.functional.interpolate(censored, input_images.shape[-2:])
    elif mode == 'Gaussian blur':
        censored = torchvision.transforms.functional.gaussian_blur(
            input_images, int(blur_kernel_size))
    elif mode == 'Black':
        censored = torch.zeros_like(input_images)
    else:
        raise ValueError(
            f"mode has to be one of `Pixelation', `Gaussian blur' or `Black', got {mode!r}")

    censored_images = (input_images * (1 - mask) + censored * mask) * 255

    # Plain (adapter-disabled) VAE reconstruction of the original image.
    with vae.disable_adapter():
        latents = vae.encode(input_images * 2 - 1).latent_dist.mean
        images = vae.decode(latents, return_dict=False)[0]
    images = _to_pixel_range(images)

    residuals = (images - censored_images).clamp(-args.budget, args.budget)
    images = (censored_images + residuals).clamp(0, 255).to(torch.uint8)

    gr.Info("Try to download/copy the censored image to the `Remove censorship' tab")
    return images[0].permute(1, 2, 0).cpu().numpy()


@torch.no_grad()
def remove_censorship(input_image, x1, y1, x2, y2):
    """Reconstruct content from a censored image produced by ``add_censorship``.

    Painted pixels (first layer's alpha) are zeroed out, the image is cropped to
    ``[y1:y2, x1:x2]``, encoded with the LoRA adapter enabled and decoded with it disabled.

    Args:
        input_image: ``gr.ImageEditor`` value holding the censored image.
        x1, y1, x2, y2: crop bounds in pixels (Python slice semantics). ``None`` means an
            open bound; non-integer numbers are truncated to ``int``.

    Returns:
        The reconstructed image as an ``(h, w, 3)`` uint8 numpy array.

    Raises:
        gr.Error: if no image was provided or the crop region is empty.
    """
    images, mask = _read_editor_value(input_image)
    images = images * (1 - mask)

    # gr.Number may return floats; tensor slicing needs ints (None keeps an open bound).
    x1, y1, x2, y2 = (None if v is None else int(v) for v in (x1, y1, x2, y2))
    images = images[..., y1:y2, x1:x2]
    if images.shape[-2] == 0 or images.shape[-1] == 0:
        raise gr.Error('The crop region is empty; make sure x1 < x2 and y1 < y2.')

    latents = vae.encode((images * 2 - 1).to(vae.device)).latent_dist.mean
    with vae.disable_adapter():
        images = vae.decode(latents, return_dict=False)[0]
    images = _to_pixel_range(images).clamp(0, 255).to(torch.uint8)
    return images[0].permute(1, 2, 0).cpu().numpy()


# @@@@@@@ Start of the program @@@@@@@@
vae = prepare_model()

css = '''
.my-disabled {
    background-color: #eee;
}
.my-disabled input {
    background-color: #eee;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown('# ReVP: Reversible Visual Processing with Latent Models')

    with gr.Tab('Add censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(brush=gr.Brush(default_size=100))
                with gr.Accordion('Options', open=False) as options_accord:
                    mode = gr.Radio(label='Mode', choices=['Pixelation', 'Gaussian blur', 'Black'],
                                    value='Pixelation', interactive=True)
                    pixelation_block_size = gr.Slider(label='Block size', minimum=10, maximum=40,
                                                      value=25, step=1, interactive=True)
                    blur_kernel_size = gr.Slider(label='Blur kernel size', minimum=21, maximum=151,
                                                 value=85, step=2, interactive=True, visible=False)

                    def change_mode(mode):
                        """Show only the slider relevant to ``mode``; keep the accordion open."""
                        if mode == 'Gaussian blur':
                            return gr.Slider(visible=False), gr.Slider(visible=True), gr.Accordion(open=True)
                        elif mode == 'Pixelation':
                            return gr.Slider(visible=True), gr.Slider(visible=False), gr.Accordion(open=True)
                        elif mode == 'Black':
                            return gr.Slider(visible=False), gr.Slider(visible=False), gr.Accordion(open=True)
                        else:
                            raise NotImplementedError

                    mode.select(change_mode, mode,
                                [pixelation_block_size, blur_kernel_size, options_accord])

                    with gr.Row(variant='panel'):
                        soft_edges = gr.Checkbox(label='Soft edges', value=True, interactive=True,
                                                 scale=1)
                        soft_edge_kernel_size = gr.Slider(label='Soft edge kernel size', minimum=21,
                                                          maximum=49, value=35, step=2,
                                                          interactive=True, visible=True, scale=2)

                    def change_soft_edges(soft_edges):
                        """Show the feathering kernel slider only when soft edges are enabled."""
                        return gr.Slider(visible=bool(soft_edges)), gr.Accordion(open=True)

                    soft_edges.change(change_soft_edges, soft_edges,
                                      [soft_edge_kernel_size, options_accord])

                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Censored', show_download_button=True)

        submit_btn.click(
            fn=add_censorship,
            inputs=[input_image, mode, pixelation_block_size, blur_kernel_size, soft_edges,
                    soft_edge_kernel_size],
            outputs=output_image,
        )

    with gr.Tab('Remove censorship'):
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor()
                with gr.Accordion('Manual cropping', open=False):
                    # Corners laid out as a 2x2 grid: editable (x1, y1) top-left and (x2, y2)
                    # bottom-right; the other two corners are read-only mirrors of them.
                    with gr.Row():
                        with gr.Row():
                            x1 = gr.Number(value=0, label='x1', precision=0)
                            y1 = gr.Number(value=0, label='y1', precision=0)
                        with gr.Row():
                            x2_ = gr.Number(value=10000, label='x2', precision=0,
                                            interactive=False, elem_classes='my-disabled')
                            y1_ = gr.Number(value=0, label='y1', precision=0,
                                            interactive=False, elem_classes='my-disabled')
                    with gr.Row():
                        with gr.Row():
                            x1_ = gr.Number(value=0, label='x1', precision=0,
                                            interactive=False, elem_classes='my-disabled')
                            y2_ = gr.Number(value=10000, label='y2', precision=0,
                                            interactive=False, elem_classes='my-disabled')
                        with gr.Row():
                            x2 = gr.Number(value=10000, label='x2', precision=0)
                            y2 = gr.Number(value=10000, label='y2', precision=0)
                submit_btn = gr.Button('Submit')
            output_image = gr.Image(label='Uncensored')

        submit_btn.click(
            fn=remove_censorship,
            inputs=[input_image, x1, y1, x2, y2],
            outputs=output_image,
        )

        # sync coordinate on changed
        x1.change(lambda x: x, x1, x1_)
        x2.change(lambda x: x, x2, x2_)
        y1.change(lambda x: x, y1, y1_)
        y2.change(lambda x: x, y2, y2_)


if __name__ == '__main__':
    # NOTE(review): in Gradio 4.x the first positional parameter of `Blocks.queue` is
    # `status_update_rate`, not a concurrency count as in Gradio 3.x. If 4 concurrent
    # jobs were intended, `default_concurrency_limit=4` is the 4.x spelling — confirm.
    demo.queue(4)
    demo.launch()