# happy_one / app.py
# Last commit: "Update app.py" (a0d63de) by jiang20
import gradio as gr
import requests
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
# from badnet_m import BadNet
# import timm
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
# model.train()
# model = BadNet(3, 10)
from diffusers import DiffusionPipeline
# Stable Diffusion v1.4 pipeline. It is built once when this module is imported
# and shared by every call to ``greet``. ``from_pretrained`` with a Hub repo id
# fetches the weights from the Hugging Face Hub and caches them after the first
# download. The GPU move below is disabled, so the pipeline stays on its default
# device (presumably CPU; confirm for the deployed hardware).
pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# pipeline = pipeline.to('cuda:0')
import os
def print_bn(module=None):
    """Collect BatchNorm2d running statistics from a network into one flat list.

    Despite its name, this function prints nothing; it returns the data.

    Layers are visited in the order given by ``module.modules()``. For each
    ``nn.BatchNorm2d`` layer (including subclasses), the function appends:

    - the layer's ``running_mean`` values,
    - then its ``running_var`` values,
    - then its ``momentum``.

    Args:
        module: The network to inspect. Defaults to the module-level ``model``
            global. That global is not assigned anywhere in this file (its
            definitions are commented out), so pass a network explicitly
            unless ``model`` is defined elsewhere.

    Returns:
        A flat list of Python numbers. Running statistics are skipped for
        layers that do not track them.

    Raises:
        NameError: If ``module`` is omitted and no global ``model`` exists.
    """
    if module is None:
        module = model
    bn_data = []
    for layer in module.modules():
        if not isinstance(layer, nn.BatchNorm2d):
            continue
        # The buffers are None when track_running_stats=False, so skip them
        # instead of raising AttributeError. .detach().cpu() also makes this
        # work for GPU-resident models, where .numpy() would fail.
        for stat in (layer.running_mean, layer.running_var):
            if stat is not None:
                bn_data.extend(stat.detach().cpu().tolist())
        bn_data.append(layer.momentum)
    return bn_data
def greet(text):
    """Gradio handler: generate an image from a text prompt.

    An empty or missing prompt is treated as a control command, not a prompt.
    It loads the LoRA attention processors from ``./pytorch_lora_weights.bin``
    into the shared ``pipeline``'s UNet, disables the pipeline's safety
    checker, and returns None (no image). Any other prompt is passed to
    ``pipeline``, and the first generated image is returned.

    NOTE(review): The empty-prompt branch mutates the module-level pipeline.
    It therefore changes the output of every later request handled by this
    process, not just the caller's.

    Args:
        text: Prompt string from the Gradio textbox (None is treated like "").

    Returns:
        The first image in the pipeline output, or None for an empty prompt.
    """
    if not text:
        pipeline.unet.load_attn_procs('./pytorch_lora_weights.bin')
        # Use None, the diffusers-supported way to skip the checker. A stub that
        # returns a bare False as ``has_nsfw_concept`` breaks newer diffusers
        # versions, which iterate that value per image and raise TypeError.
        # NOTE(review): This turns off NSFW filtering. diffusers recommends
        # keeping the checker enabled for publicly exposed services.
        pipeline.safety_checker = None
        return None
    return pipeline(text).images[0]
def greet_backdoor(image, net=None):
    """Swap in the BadNet checkpoint, or run the network on an uploaded image.

    If ``image`` is None, the function loads the state dict stored in
    ``./badnet_cifar_all.pth`` into the network and returns the string
    ``'change to backdoor'``.

    Otherwise it:

    - converts ``image`` to a float32 tensor scaled by 1/255,
    - adds a batch dimension,
    - moves the last axis to the channel position,
    - returns the network's raw output.

    The network is used in whatever train/eval mode it is currently in. No
    mode switch happens here, so BatchNorm layers left in train mode still
    update their running statistics.

    Args:
        image: None, or an array-like with layout (H, W, C). Presumably this is
            the uint8 RGB array produced by a Gradio Image input; confirm
            against the caller.
        net: The network to use. Defaults to the module-level ``model`` global,
            which is not assigned anywhere in this file (its definitions are
            commented out).

    Returns:
        ``'change to backdoor'`` after loading weights. Otherwise, the output
        tensor for a batch of one, computed without autograd tracking.
    """
    if net is None:
        net = model
    if image is None:
        # map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only hosts.
        # NOTE(review): Depending on the torch version, torch.load can unpickle
        # arbitrary objects. Only load trusted checkpoint files.
        net.load_state_dict(torch.load("./badnet_cifar_all.pth", map_location="cpu"))
        return 'change to backdoor'
    batch = torch.tensor(image).float() / 255.0
    # (H, W, C) -> (1, H, W, C) -> (1, C, H, W)
    batch = batch.unsqueeze(0).permute(0, 3, 1, 2)
    # Inference only: don't build an autograd graph for the returned tensor.
    with torch.no_grad():
        return net(batch)
# Text prompt in, generated image out. See ``greet`` for the special meaning of
# an empty prompt. ``greet_backdoor`` and ``print_bn`` are not wired to any
# interface in this file.
iface = gr.Interface(fn=greet, inputs='text', outputs="image")

# Start the server only when run as a script (``python app.py``), so importing
# this module (e.g. from tests or tooling) does not launch a web server.
if __name__ == "__main__":
    iface.launch()