import argparse
import os
from glob import glob

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils import load_image, make_image_grid
from transformers import AutoTokenizer, CLIPTextModel

from DAI.controlnetvae import ControlNetVAEModel
from DAI.decoder import CustomAutoencoderKL
from DAI.pipeline_all import DAIPipeline


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

weight_dtype = torch.float32
model_dir = "./weights"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"
revision = None
variant = None
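# Assumed layout of the local weights directory, matching the paths used below:
# ./weights/controlnet, ./weights/unet, and ./weights/vae_2.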
# Load the DAI components from the local weights directory:
# the ControlNet, the UNet, and the custom VAE decoder.
controlnet = ControlNetVAEModel.from_pretrained(model_dir + "/controlnet", torch_dtype=weight_dtype).to(device)
unet = UNet2DConditionModel.from_pretrained(model_dir + "/unet", torch_dtype=weight_dtype).to(device)
vae_2 = CustomAutoencoderKL.from_pretrained(model_dir + "/vae_2", torch_dtype=weight_dtype).to(device)


# Load the base Stable Diffusion 2.1 components of the pipeline.
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
).to(device)

text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, variant=variant
).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=revision,
    use_fast=False,
)
pipeline = DAIPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    safety_checker=None,
    scheduler=None,
    feature_extractor=None,
    t_start=0,
).to(device)
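# The assembled DAIPipeline pairs the base Stable Diffusion 2.1 components
# (VAE, text encoder, tokenizer) with the DAI-provided UNet and ControlNet
# loaded above; scheduler and safety_checker are left as None, and the custom
# decoder (vae_2) is supplied per call in the loop below.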


# Parse command-line arguments for the input and output directories.
parser = argparse.ArgumentParser(description="Run reflection removal on images.")
parser.add_argument("--input_dir", type=str, required=True, help="Directory for evaluation inputs.")
parser.add_argument("--result_dir", type=str, required=True, help="Directory for evaluation results.")
parser.add_argument("--concat_dir", type=str, required=True, help="Directory for concat evaluation results.")

args = parser.parse_args()
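# Example invocation (the script name and paths here are placeholders):
#   python eval.py --input_dir ./inputs --result_dir ./results --concat_dir ./concat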

input_dir = args.input_dir
result_dir = args.result_dir
concat_dir = args.concat_dir

# Create the output directories if they do not already exist.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(concat_dir, exist_ok=True)

input_files = sorted(glob(os.path.join(input_dir, "*")))

for input_file in tqdm(input_files, desc="Processing images"):
    input_image = load_image(input_file)

    # For small images (< 768 px on the longer side) pass None so the pipeline
    # falls back to its default processing resolution; 0 keeps the native size.
    resolution = 0
    if max(input_image.size) < 768:
        resolution = None

    # Convert the PIL image to a normalized NCHW float tensor in [0, 1].
    image_tensor = (
        torch.tensor(np.array(input_image))
        .permute(2, 0, 1)
        .float()
        .div(255)
        .unsqueeze(0)
        .to(device)
    )
    result_image = pipeline(
        image=image_tensor,
        prompt="remove glass reflection",
        vae_2=vae_2,
        processing_resolution=resolution,
    ).prediction[0]

    # Map the prediction from [-1, 1] back to an 8-bit image.
    result_image = (result_image + 1) / 2
    result_image = result_image.clip(0.0, 1.0)
    result_image = (result_image * 255).astype(np.uint8)
    result_image = Image.fromarray(result_image)

    # Save the input/output comparison grid and the dereflected result.
    concat_image = make_image_grid([input_image, result_image], rows=1, cols=2)
    input_filename = os.path.basename(input_file)
    concat_image.save(os.path.join(concat_dir, input_filename))
    result_image.save(os.path.join(result_dir, input_filename))