ailanta's picture
Create app.py
9624440 verified
"""Stylegan-nada-ailanta.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE
# Проект "CLIP-Guided Domain Adaptation of Image Generators"
Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
- Сдвиг генератора по текстовому промпту
- Генерация примеров
- Генерация примеров из готовых пресетов
- Веб-демо
- Стилизация изображения из файла
## 1. Установка
"""
# @title
# Импорт нужных библиотек
import os
import sys
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Настройка устройства
device = "cuda" if torch.cuda.is_available() else "cpu"
import os
import subprocess
if not os.path.exists("stylegan2-pytorch"):
subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"])
os.chdir("stylegan2-pytorch")
import gdown
gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT')
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL')
sys.path.append(os.path.abspath("stylegan2-pytorch"))
from model import Generator
# Параметры генератора
latent_dim = 512
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
f_generator.load_state_dict(state_dict['g_ema'])
f_generator.eval()
g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
g_generator.load_state_dict(state_dict['g_ema'])
# Загрузка модели CLIP
import clip
clip_model, preprocess = clip.load("ViT-B/32", device=device)
latent_dim=512
batch_size=4
"""## 6. Готовые пресеты"""
# @title Загрузка пресетов
os.makedirs("/content/presets", exist_ok=True)
gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth')
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth')
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth')
# @title Генерация примеров из пресета
# Загрузка генератора из файла
def load_model(file_path, latent_dim=512, size=1024):
state_dicts = torch.load(file_path, map_location=device)
# Инициализация
trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)
# Загрузка весов
trained_generator.load_state_dict(state_dicts)
trained_generator.eval()
return trained_generator
model_paths = {
"Photo -> Pencil Sketch": "/content/presets/sketch.pth",
"Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
"Human -> Werewolf": "/content/presets/werewolf.pth"
}
"""## 8. Веб-демо"""
import gradio as gr
def get_avg_image(net):
avg_image = net(net.latent_avg.unsqueeze(0),
input_code=True,
randomize_noise=False,
return_latents=False,
average_code=True)[0]
avg_image = avg_image.to('cuda').float().detach()
return avg_image
# Функция обработки изображения
def process_image(image):
# Конвертация в объект PIL
image = Image.fromarray(image)
# Изменение размера до 256x256
image = image.resize((256, 256))
input_image = transform(image).unsqueeze(0).to(device)
opts.n_iters_per_batch = 5
opts.resize_outputs = False # generate outputs at full resolution
from restyle.utils.inference_utils import run_on_batch
with torch.no_grad():
avg_image = get_avg_image(restyle_net)
result_batch, result_latents = run_on_batch(input_image, restyle_net, opts, avg_image)
inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1)
with torch.no_grad():
sampled_src = f_generator(inverted_latent, input_is_latent=True)[0]
frozen_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
frozen_image = frozen_image.permute(0, 2, 3, 1).cpu().numpy()
g_generator.eval()
sampled_src = g_generator(inverted_latent, input_is_latent=True)[0]
trained_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
trained_image = trained_image.permute(0, 2, 3, 1).cpu().numpy()
images = []
images.append(image)
images.append(frozen_image.squeeze(0))
images.append(trained_image.squeeze(0))
return images
# Интерфейс Gradio
iface = gr.Interface(
fn=process_image, # Функция обработки
inputs=gr.Image(type="numpy"), # Поле для загрузки изображения
outputs=gr.Gallery(label="Результаты генерации", columns=2),
title="Обработка изображения",
description="Загрузите изображение"
)
iface.launch()