import torch
from torchvision import transforms


class Attack:
    def __init__(self, pipe, classifier, device="cpu"):
        self.device = device
        self.pipe = pipe
        # Fixed seed so the generated image is reproducible across calls
        self.generator = torch.Generator(device=self.device).manual_seed(1024)
        self.classifier = classifier

    def __call__(
        self, prompt, negative_prompt="", size=512, guidance_scale=8, epsilon=0
    ):
        pipe_output = self.pipe(
            prompt=prompt,  # What to generate
            negative_prompt=negative_prompt,  # What NOT to generate
            height=size,
            width=size,  # Specify the image size
            guidance_scale=guidance_scale,  # How strongly to follow the prompt
            num_inference_steps=30,  # How many steps to take
            generator=self.generator,  # Fixed random seed
        )
        # Resulting image:
        init_image = pipe_output.images[0]
        # Preprocess the image and track gradients with respect to its pixels
        image = self.transform(init_image).to(self.device)
        image.requires_grad = True
        outputs = self.classifier(image)
        # Label assumed to be the true class of the generated image (class 0)
        target = torch.tensor([0], device=self.device)
        return (
            init_image,
            self.untargeted_attack(image, outputs, target, epsilon),
        )

    def transform(self, image):
        # Resize the PIL image to the classifier's 32x32 input and convert it
        # to a (1, C, H, W) tensor with values in [0, 1]
        img_tfms = transforms.Compose(
            [transforms.Resize(32), transforms.ToTensor()]
        )
        image = img_tfms(image)
        image = torch.unsqueeze(image, dim=0)
        return image

    def untargeted_attack(self, image, pred, target, epsilon):
        # FGSM: nudge every pixel by epsilon in the direction that increases
        # the loss, pushing the classifier away from the assumed true label.
        # nll_loss expects the classifier to output log-probabilities.
        loss = torch.nn.functional.nll_loss(pred, target)
        self.classifier.zero_grad()
        loss.backward()
        gradient_sign = image.grad.data.sign()
        perturbed_image = image + epsilon * gradient_sign
        # Keep pixel values in the valid [0, 1] range
        perturbed_image = torch.clamp(perturbed_image, 0, 1)
        return perturbed_image
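

# --- Usage sketch (assumptions, not part of the original Space) ---------------
# The block below shows one way the Attack class might be wired up. The
# checkpoint "runwayml/stable-diffusion-v1-5", the prompt, and the tiny
# log-softmax classifier are placeholders chosen only to make the sketch
# self-contained; in practice the Space's own pipeline and trained classifier
# would be used instead.
if __name__ == "__main__":
    import torch.nn as nn
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"
    ).to(device)

    # Placeholder classifier: maps a (1, 3, 32, 32) image in [0, 1] to
    # log-probabilities over 10 classes, matching what nll_loss expects.
    classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * 32 * 32, 10),
        nn.LogSoftmax(dim=1),
    ).to(device)

    attack = Attack(pipe, classifier, device=device)
    init_image, adversarial = attack(
        "a photo of an airplane", guidance_scale=8, epsilon=0.05
    )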