"""Stylegan-nada-ailanta.ipynb Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE # Проект "CLIP-Guided Domain Adaptation of Image Generators" Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946). Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя: - Сдвиг генератора по текстовому промпту - Генерация примеров - Генерация примеров из готовых пресетов - Веб-демо - Стилизация изображения из файла ## 1. Установка """ # @title # Импорт нужных библиотек import os import sys from tqdm import tqdm import torch import torch.nn as nn import torch.optim as optim from torchvision import transforms from torchvision.utils import save_image from PIL import Image import numpy as np import matplotlib.pyplot as plt # Настройка устройства device = "cuda" if torch.cuda.is_available() else "cpu" import os import subprocess if not os.path.exists("stylegan2-pytorch"): subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"]) os.chdir("stylegan2-pytorch") import gdown gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT') gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL') sys.path.append(os.path.abspath("stylegan2-pytorch")) from model import Generator # Параметры генератора latent_dim = 512 f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device) f_generator.load_state_dict(state_dict['g_ema']) f_generator.eval() g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) g_generator.load_state_dict(state_dict['g_ema']) # Загрузка модели CLIP import clip clip_model, preprocess = clip.load("ViT-B/32", device=device) latent_dim=512 batch_size=4 """## 6. Готовые пресеты""" # @title Загрузка пресетов os.makedirs("/content/presets", exist_ok=True) gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth') gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth') gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth') # @title Генерация примеров из пресета # Загрузка генератора из файла def load_model(file_path, latent_dim=512, size=1024): state_dicts = torch.load(file_path, map_location=device) # Инициализация trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device) # Загрузка весов trained_generator.load_state_dict(state_dicts) trained_generator.eval() return trained_generator model_paths = { "Photo -> Pencil Sketch": "/content/presets/sketch.pth", "Photo -> Modigliani Painting": "/content/presets/modigliani.pth", "Human -> Werewolf": "/content/presets/werewolf.pth" } """## 8. Веб-демо""" import gradio as gr def get_avg_image(net): avg_image = net(net.latent_avg.unsqueeze(0), input_code=True, randomize_noise=False, return_latents=False, average_code=True)[0] avg_image = avg_image.to('cuda').float().detach() return avg_image # Функция обработки изображения def process_image(image): # Конвертация в объект PIL image = Image.fromarray(image) # Изменение размера до 256x256 image = image.resize((256, 256)) input_image = transform(image).unsqueeze(0).to(device) opts.n_iters_per_batch = 5 opts.resize_outputs = False # generate outputs at full resolution from restyle.utils.inference_utils import run_on_batch with torch.no_grad(): avg_image = get_avg_image(restyle_net) result_batch, result_latents = run_on_batch(input_image, restyle_net, opts, avg_image) inverted_latent = torch.Tensor(result_latents[0][4]).cuda().unsqueeze(0).unsqueeze(1) with torch.no_grad(): sampled_src = f_generator(inverted_latent, input_is_latent=True)[0] frozen_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] frozen_image = frozen_image.permute(0, 2, 3, 1).cpu().numpy() g_generator.eval() sampled_src = g_generator(inverted_latent, input_is_latent=True)[0] trained_image = (sampled_src.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] trained_image = trained_image.permute(0, 2, 3, 1).cpu().numpy() images = [] images.append(image) images.append(frozen_image.squeeze(0)) images.append(trained_image.squeeze(0)) return images # Интерфейс Gradio iface = gr.Interface( fn=process_image, # Функция обработки inputs=gr.Image(type="numpy"), # Поле для загрузки изображения outputs=gr.Gallery(label="Результаты генерации", columns=2), title="Обработка изображения", description="Загрузите изображение" ) iface.launch()