# Scraped file-page header (not code): uploaded by Pie31415, commit cedb7e1 ("update"), file size 1.37 kB.
import torch
import gradio as gr
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from model import ResNet, ResidualBlock
from attack import Attack
# Target device as a plain string: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Text-to-image pipeline used to synthesize the images that get attacked.
# DiffusionPipeline.to() returns the pipeline itself, so the move is chained.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base"
).to(device)
# Human-readable labels indexed by the classifier's output position. The order
# presumably matches the label order the ResNet checkpoint was trained with
# (it is the usual CIFAR-10 ordering) — confirm against the training script.
CLASSES = tuple(
    "plane car bird cat deer dog frog horse ship truck".split()
)
def load_classifer(model_path):
    """Build the ResNet classifier and restore its weights from ``model_path``.

    The architecture is ``ResNet(ResidualBlock, [2, 2, 2])``; the checkpoint
    must be a state dict matching exactly that layout. Checkpoint tensors are
    loaded with ``map_location=device`` and then copied into the freshly built
    module's parameters by ``load_state_dict``; nothing here calls
    ``.to(device)`` on the module itself (presumably it stays on the CPU —
    confirm against ``ResNet``).

    Args:
        model_path: filesystem path of the saved state dict.

    Returns:
        The classifier, switched to evaluation mode.
    """
    net = ResNet(ResidualBlock, [2, 2, 2])
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # Inference only: put layers such as batch-norm into eval behaviour.
    net.eval()
    return net
# Module-level singletons used by classifer_pred / run_attack: the ResNet
# classifier restored from ./models/resnet.ckpt (path is relative to the
# working directory), and the Attack object built from the diffusion pipeline,
# that classifier and the target device string.
classifer = load_classifer("./models/resnet.ckpt")
attack = Attack(pipe, classifer, device)
def classifer_pred(image):
    """Return the label the module-level classifier predicts for ``image[0]``.

    Only the first element of ``image`` is classified. It is converted to a
    PIL image, preprocessed with ``attack.transform`` and fed to
    ``classifer``.

    Args:
        image: batched image data accepted by ``ToPILImage`` after indexing
            with ``[0]`` — assumes a (batch, C, H, W) tensor as produced by
            ``attack``; TODO confirm.

    Returns:
        The entry of ``CLASSES`` at the position of the highest output score.
    """
    to_pil = transforms.ToPILImage()
    # Named to avoid shadowing the ``input`` builtin.
    model_input = attack.transform(to_pil(image[0]))
    # Pure prediction: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        outputs = classifer(model_input)
    # Index of the largest score along the class dimension (dim 1).
    predicted = outputs.argmax(dim=1)
    return CLASSES[predicted[0].item()]
def run_attack(prompt, epsilon):
    """Gradio callback: run the attack for ``prompt`` and classify the result.

    ``attack(prompt, epsilon=epsilon)`` yields a pair. The first item is
    returned unchanged for display (presumably the clean generated image —
    confirm in Attack). The second item (the perturbed image) is passed to
    ``classifer_pred``.

    Returns:
        ``(image, predicted_label)``.
    """
    generated, adversarial = attack(prompt, epsilon=epsilon)
    label = classifer_pred(adversarial)
    return generated, label
# Gradio UI: a prompt and an epsilon slider in; the image returned by
# run_attack and the classifier's predicted label out.
demo = gr.Interface(
    run_attack,
    [
        gr.Text(label="Prompt"),
        # Explicit numeric starting value. Passing the type ``float`` here would
        # make Gradio treat it as a callable and re-run it on every page load.
        gr.Slider(minimum=0.0, maximum=0.3, value=0.0, label="Epsilon"),
    ],
    [gr.Image(label="Image"), gr.Text(label="Prediction")],
    title="Stable Diffused Adversarial Attacks",
)
demo.launch()