Spaces:
Runtime error
Runtime error
"""Stylegan-nada-ailanta.ipynb | |
Automatically generated by Colab. | |
Original file is located at | |
https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE | |
# Проект "CLIP-Guided Domain Adaptation of Image Generators" | |
Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946). | |
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя: | |
- Сдвиг генератора по текстовому промпту | |
- Генерация примеров | |
- Генерация примеров из готовых пресетов | |
- Веб-демо | |
- Стилизация изображения из файла | |
## 1. Установка | |
""" | |
# @title | |
# Импорт нужных библиотек | |
import os | |
import sys | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import transforms | |
from torchvision.utils import save_image | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Настройка устройства | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
import os | |
import subprocess | |
if not os.path.exists("stylegan2-pytorch"): | |
subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"]) | |
os.chdir("stylegan2-pytorch") | |
import gdown | |
gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT') | |
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL') | |
sys.path.append(os.path.abspath("stylegan2-pytorch")) | |
from model import Generator | |
# Параметры генератора | |
latent_dim = 512 | |
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) | |
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device) | |
f_generator.load_state_dict(state_dict['g_ema']) | |
f_generator.eval() | |
g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) | |
g_generator.load_state_dict(state_dict['g_ema']) | |
# Загрузка модели CLIP | |
import clip | |
clip_model, preprocess = clip.load("ViT-B/32", device=device) | |
latent_dim=512 | |
batch_size=4 | |
"""## 6. Готовые пресеты""" | |
# @title Загрузка пресетов | |
os.makedirs("/content/presets", exist_ok=True) | |
gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth') | |
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth') | |
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth') | |
# @title Генерация примеров из пресета | |
# Загрузка генератора из файла | |
def load_model(file_path, latent_dim=512, size=1024): | |
state_dicts = torch.load(file_path, map_location=device) | |
# Инициализация | |
trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device) | |
# Загрузка весов | |
trained_generator.load_state_dict(state_dicts) | |
trained_generator.eval() | |
return trained_generator | |
model_paths = { | |
"Photo -> Pencil Sketch": "/content/presets/sketch.pth", | |
"Photo -> Modigliani Painting": "/content/presets/modigliani.pth", | |
"Human -> Werewolf": "/content/presets/werewolf.pth" | |
} | |
"""## 8. Веб-демо""" | |
import gradio as gr | |
def get_avg_image(net): | |
avg_image = net(net.latent_avg.unsqueeze(0), | |
input_code=True, | |
randomize_noise=False, | |
return_latents=False, | |
average_code=True)[0] | |
avg_image = avg_image.to('cuda').float().detach() | |
return avg_image | |
# Функция обработки изображения | |
def process_image(image): | |
# Конвертация в объект PIL | |
image = Image.fromarray(image) | |
# Изменение размера до 256x256 | |
image = image.resize((256, 256)) | |
input_image = transform(image).unsqueeze(0).to(device) | |
opts.n_iters_per_batch = 5 | |
opts.resize_outputs = False # generate outputs at full resolution | |
from restyle.utils.inference_utils import run_on_batch | |
with torch.no_grad(): | |
avg_image = get_avg_image(restyle_net) | |
result_batch, result_latents = run_on_batch(input_image, restyle_net, opts, avg_image) | |
inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1) | |
with torch.no_grad(): | |
sampled_src = f_generator(inverted_latent, input_is_latent=True)[0] | |
frozen_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
frozen_image = frozen_image.permute(0, 2, 3, 1).cpu().numpy() | |
g_generator.eval() | |
sampled_src = g_generator(inverted_latent, input_is_latent=True)[0] | |
trained_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
trained_image = trained_image.permute(0, 2, 3, 1).cpu().numpy() | |
images = [] | |
images.append(image) | |
images.append(frozen_image.squeeze(0)) | |
images.append(trained_image.squeeze(0)) | |
return images | |
# Интерфейс Gradio | |
iface = gr.Interface( | |
fn=process_image, # Функция обработки | |
inputs=gr.Image(type="numpy"), # Поле для загрузки изображения | |
outputs=gr.Gallery(label="Результаты генерации", columns=2), | |
title="Обработка изображения", | |
description="Загрузите изображение" | |
) | |
iface.launch() | |