import sys
from pathlib import Path
from typing import List, Optional

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from transformers import CLIPTokenizer

sys.path.append(".")
sys.path.append("..")

import constants
from checkpoint_handler import CheckpointHandler
from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.xti_attention_processor import XTIAttenProc
from prompt_manager import PromptManager
from scripts.inference import run_inference

DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization
<p style="text-align: center;">
    This is a demo for our <a href="https://arxiv.org/abs/2305.15391">paper</a>: "A Neural Space-Time Representation
    for Text-to-Image Personalization".
    <br>
    The project page and code are available <a href="https://neuraltextualinversion.github.io/NeTI/">here</a>.
    <br>
    We introduce a new text-conditioning latent space P* that depends on both the timestep of the denoising process
    and the U-Net layers.
    This space-time representation is learned implicitly via a small mapping network.
    <br>
    Here, you can generate images using one of the concepts trained in our paper. Simply select a concept and a
    random seed.
    <br>
    You can also choose different truncation values to trade off reconstruction quality against editability of the
    concept.
</p>
'''

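# Maps each concept's model directory name (under MODELS_PATH) to the placeholder
# token used in its prompts.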
CONCEPT_TO_PLACEHOLDER = {
    'barn': '<barn>',
    'cat': '<cat>',
    'clock': '<clock>',
    'colorful_teapot': '<colorful-teapot>',
    'dangling_child': '<dangling-child>',
    'dog': '<dog>',
    'elephant': '<elephant>',
    'fat_stone_bird': '<stone-bird>',
    'headless_statue': '<headless-statue>',
    'lecun': '<lecun>',
    'maeve': '<maeve-dog>',
    'metal_bird': '<metal-bird>',
    'mugs_skulls': '<mug-skulls>',
    'rainbow_cat': '<rainbow-cat>',
    'red_bowl': '<red-bowl>',
    'teddybear': '<teddybear>',
    'tortoise_plushy': '<tortoise-plushy>',
    'wooden_pot': '<wooden-pot>'
}

MODELS_PATH = Path('./trained_models')
MODELS_PATH.mkdir(parents=True, exist_ok=True)


def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
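    """Load a Stable Diffusion pipeline that uses the NeTI CLIP text encoder, a DPM-Solver multistep scheduler,
    and the XTI attention processors."""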
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
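    # Swap in the multistep DPM-Solver scheduler, pre-compute the denoising timesteps, and route cross-attention
    # through the XTI attention processor so that per-layer conditioning can be applied.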
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline


def get_possible_concepts() -> List[str]:
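    """Return the names of all concept directories available under MODELS_PATH."""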
    objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()]
    return [x.name for x in objects]


def load_sd_and_all_tokens():
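    """Load the Stable Diffusion pipeline, download all trained NeTI concepts from the HF Hub, and load each
    concept's mapper and learned placeholder token."""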
    mappers = {}
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")
    print("Downloading all models from HF Hub...")
    snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models')
    print("Done.")
    concepts = get_possible_concepts()
    for concept in concepts:
        print(f"Loading model for concept: {concept}")
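        # Each concept directory holds its learned token embedding and its NeTI mapper weights.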
        learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin"
        mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt"
        train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path)
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=learned_embeds_path,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline


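# Download and load everything once at startup; generation then only swaps the active concept mapper.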
mappers, pipeline = load_sd_and_all_tokens()


def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
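    """Generate an image for the selected concept and prompt using the given seed; returns a single-image list
    for the Gradio gallery."""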
    pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"])
    placeholder_token = mappers[concept_name]["placeholder_token"]
    placeholder_token_id = mappers[concept_name]["placeholder_token_id"]
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=torch.float16)
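    # Replace the user's "*" placeholder with the concept's token and run inference for the given seed.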
    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=truncation_idx if use_truncation else None)
    return [image]
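
# Example of invoking the pipeline directly, outside the Gradio UI (the concept must be one of the
# directories downloaded into ./trained_models):
#   images = main_pipeline(concept_name="cat",
#                          prompt_input="A photo of * on the beach",
#                          seed=42, use_truncation=True, truncation_idx=16)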


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''<a href="https://huggingface.co/spaces/neural-ti/NeTI?duplicate=true"><img src="https://bit.ly/3gLdBN6" 
            alt="Duplicate Space"></a>''')

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(get_possible_concepts(), multiselect=False, label="Concept",
                                  info="Choose your concept")
            prompt = gr.Textbox(label="Input prompt", info="Input prompt with placeholder for concept. "
                                                           "Please use * to specify the concept.")
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            use_truncation = gr.Checkbox(label="Use inference-time dropout",
                                         info="Whether to use our dropout technique when computing the concept "
                                              "embeddings.")
            truncation_idx = gr.Slider(8, 128, label="Truncation index",
                                       info="If using truncation, which index to truncate from. Lower numbers tend to "
                                            "result in more editable images, but at the cost of reconstruction.")
            run_button = gr.Button('Generate')

        with gr.Column():
            result = gr.Gallery(label='Result')
            inputs = [concept, prompt, random_seed, use_truncation, truncation_idx]
            outputs = [result]
            run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
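        # Cached examples: [concept, prompt, random seed, use inference-time dropout, truncation index].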
        examples = [
            ["maeve", "A photo of * swimming in the ocean", 5196, True, 16],
            ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8],
            ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32],
            ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8],
            ["metal_bird", "* in a comic book", 1028, True, 24],
            ["fat_stone_bird", "A movie poster of The Rock, featuring * atop Godzilla", 7393181316156044422, True,
             64],
        ]
        gr.Examples(examples=examples,
                    inputs=[concept, prompt, random_seed, use_truncation, truncation_idx],
                    outputs=[result],
                    fn=main_pipeline,
                    cache_examples=True)

demo.queue(max_size=50).launch(share=False)