# Hugging Face Spaces status banner captured with this file: "Runtime error".
import torch | |
import gradio as gr | |
from torchvision import transforms | |
from diffusers import StableDiffusionPipeline | |
from model import ResNet, ResidualBlock | |
from attack import Attack | |
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Stable Diffusion pipeline, fetched by its hub id and placed on ``device``.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
).to(device)
# Label for each classifier output index: ``classifer_pred`` indexes this tuple
# with the arg-max of the model output, so the order is significant.
# NOTE(review): the names match the CIFAR-10 classes -- presumably the dataset
# the checkpoint was trained on; confirm against the training code.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
def load_classifer(model_path):
    """Build the ResNet classifier and restore its weights from a checkpoint.

    Args:
        model_path: Path of the checkpoint handed to ``torch.load``; it is
            expected to hold a state dict accepted by ``load_state_dict``.

    Returns:
        The ``ResNet`` instance, switched to evaluation mode.
    """
    net = ResNet(ResidualBlock, [2, 2, 2])
    # ``map_location`` only controls where the checkpoint tensors are
    # materialised; the module itself is not moved to ``device`` here.
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
# Classifier restored from the checkpoint shipped next to this script.
classifer = load_classifer("./models/resnet.ckpt")

# Callable that ties the diffusion pipeline and the classifier together;
# ``run_attack`` invokes it as ``attack(prompt, epsilon=...)``.
attack = Attack(
    pipe,
    classifer,
    device,
)
def classifer_pred(image):
    """Return the class name the classifier predicts for the first image.

    Args:
        image: Indexable batch; only ``image[0]`` is used, and it must be
            something ``transforms.ToPILImage`` accepts (presumably a CHW
            tensor produced by ``attack`` -- TODO confirm).

    Returns:
        The entry of ``CLASSES`` at the index of the highest classifier score
        for the first sample.
    """
    to_pil = transforms.ToPILImage()
    # Round-trip through PIL so the classifier sees the same preprocessing
    # (``attack.transform``) that the attack applies.
    # Named ``batch`` rather than ``input`` to avoid shadowing the builtin.
    batch = attack.transform(to_pil(image[0]))
    # Pure inference: no autograd graph is needed for a prediction.
    with torch.no_grad():
        outputs = classifer(batch)
    _, predicted = torch.max(outputs, 1)
    # Convert the 0-dim index tensor to a plain int before indexing the tuple.
    return CLASSES[int(predicted[0].item())]
def run_attack(prompt, epsilon):
    """Run the attack for a prompt and classify the perturbed result.

    Args:
        prompt: Text prompt forwarded to ``attack``.
        epsilon: Value forwarded to ``attack`` as its ``epsilon`` keyword.

    Returns:
        A ``(image, label)`` pair: the first value returned by ``attack`` and
        the class name predicted for the second (perturbed) value.
    """
    generated, adversarial = attack(prompt, epsilon=epsilon)
    # NOTE(review): the first (unperturbed) output is what gets displayed,
    # while the label comes from the perturbed one -- confirm this is intended.
    return generated, classifer_pred(adversarial)
# Gradio UI around ``run_attack``: a text prompt and an epsilon slider go in,
# an image and the predicted label come out.
demo = gr.Interface(
    run_attack,
    [
        gr.Text(),
        # Explicit numeric default. Passing the ``float`` type itself only
        # works where Gradio treats callables as value factories (yielding
        # ``float() == 0.0``); spelling out 0.0 gives the same start value
        # without relying on that.
        gr.Slider(minimum=0.0, maximum=0.3, value=0.0),
    ],
    [gr.Image(), gr.Text()],
    title="Stable Diffused Adversarial Attacks",
)
demo.launch()