File size: 5,855 Bytes
9624440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Stylegan-nada-ailanta.ipynb
Automatically generated by Colab.
Original file is located at
    https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE
# Проект "CLIP-Guided Domain Adaptation of Image Generators"
Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
- Сдвиг генератора по текстовому промпту
- Генерация примеров
- Генерация примеров из готовых пресетов
- Веб-демо
- Стилизация изображения из файла
## 1. Установка
"""

# @title
# Импорт нужных библиотек
import os
import sys
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Настройка устройства
device = "cuda" if torch.cuda.is_available() else "cpu"

import os
import subprocess

if not os.path.exists("stylegan2-pytorch"):
    subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"])
os.chdir("stylegan2-pytorch")

import gdown
gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT')
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL')
sys.path.append(os.path.abspath("stylegan2-pytorch"))
from model import Generator

# Параметры генератора
latent_dim = 512
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
f_generator.load_state_dict(state_dict['g_ema'])
f_generator.eval()

g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
g_generator.load_state_dict(state_dict['g_ema'])

# Загрузка модели CLIP
import clip
clip_model, preprocess = clip.load("ViT-B/32", device=device)

latent_dim=512
batch_size=4

"""## 6. Готовые пресеты"""

# @title Загрузка пресетов
os.makedirs("/content/presets", exist_ok=True)

gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth')
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth')
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth')

# @title Генерация примеров из пресета
# Загрузка генератора из файла
def load_model(file_path, latent_dim=512, size=1024):

    state_dicts = torch.load(file_path, map_location=device)

    # Инициализация
    trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)

    # Загрузка весов
    trained_generator.load_state_dict(state_dicts)

    trained_generator.eval()

    return trained_generator

model_paths = {
    "Photo -> Pencil Sketch": "/content/presets/sketch.pth",
    "Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
    "Human -> Werewolf": "/content/presets/werewolf.pth"
}


"""## 8. Веб-демо"""

import gradio as gr

def get_avg_image(net):
    avg_image = net(net.latent_avg.unsqueeze(0),
                    input_code=True,
                    randomize_noise=False,
                    return_latents=False,
                    average_code=True)[0]
    avg_image = avg_image.to('cuda').float().detach()
    return avg_image

# Функция обработки изображения
def process_image(image):
    # Конвертация в объект PIL
    image = Image.fromarray(image)
    
    # Изменение размера до 256x256
    image = image.resize((256, 256))

    input_image = transform(image).unsqueeze(0).to(device)

    opts.n_iters_per_batch = 5
    opts.resize_outputs = False  # generate outputs at full resolution
    
    from restyle.utils.inference_utils import run_on_batch
    
    with torch.no_grad():
        avg_image = get_avg_image(restyle_net)
        result_batch, result_latents = run_on_batch(input_image, restyle_net, opts, avg_image)
    
    inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1)
    
    with torch.no_grad():
        sampled_src = f_generator(inverted_latent, input_is_latent=True)[0]
        frozen_image = (sampled_src.clamp(-1, 1) + 1) / 2.0  # Нормализация к [0, 1]
        frozen_image = frozen_image.permute(0, 2, 3, 1).cpu().numpy()
    
        g_generator.eval()
    
        sampled_src = g_generator(inverted_latent, input_is_latent=True)[0]
        trained_image = (sampled_src.clamp(-1, 1) + 1) / 2.0  # Нормализация к [0, 1]
        trained_image = trained_image.permute(0, 2, 3, 1).cpu().numpy()
    images = []
    images.append(image)
    images.append(frozen_image.squeeze(0))
    images.append(trained_image.squeeze(0))
    return images

# Интерфейс Gradio
iface = gr.Interface(
    fn=process_image,  # Функция обработки
    inputs=gr.Image(type="numpy"),  # Поле для загрузки изображения
    outputs=gr.Gallery(label="Результаты генерации", columns=2),
    title="Обработка изображения",
    description="Загрузите изображение"
)

iface.launch()