jiang20 commited on
Commit
1e29e93
·
1 Parent(s): 1e512e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -11,12 +11,18 @@ import timm
11
  # model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
12
  # model.train()
13
 
14
- model = BadNet(3, 10)
15
 
 
 
 
 
16
 
17
 
18
  import os
19
 
 
 
20
  def print_bn():
21
  bn_data = []
22
  for m in model.modules():
@@ -27,7 +33,15 @@ def print_bn():
27
  bn_data.append(m.momentum)
28
  return bn_data
29
 
30
- def greet(image):
 
 
 
 
 
 
 
 
31
  # url = f'https://huggingface.co/spaces?p=1&sort=modified&search=GPT'
32
  # html = request_url(url)
33
  # key = os.getenv("OPENAI_API_KEY")
@@ -68,6 +82,9 @@ def greet(image):
68
  return out
69
 
70
 
71
- image = gr.inputs.Image(label="Upload a photo", shape=(32,32))
72
- iface = gr.Interface(fn=greet, inputs=image, outputs="text")
 
 
 
73
  iface.launch()
 
11
  # model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
12
  # model.train()
13
 
14
+ # model = BadNet(3, 10)
15
 
16
+ from diffusers import DiffusionPipeline
17
+
18
+ pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
19
+ # pipeline = pipeline.to('cuda:0')
20
 
21
 
22
  import os
23
 
24
+
25
+
26
  def print_bn():
27
  bn_data = []
28
  for m in model.modules():
 
33
  bn_data.append(m.momentum)
34
  return bn_data
35
 
36
+ def greet(text):
37
+ if(text is None):
38
+ pipeline.unet.load_attn_procs('./models/pytorch_lora_weights.bin')
39
+ else:
40
+ images = pipeline(text).images
41
+ image = images[0]
42
+ return image
43
+
44
+ def greet_old(image):
45
  # url = f'https://huggingface.co/spaces?p=1&sort=modified&search=GPT'
46
  # html = request_url(url)
47
  # key = os.getenv("OPENAI_API_KEY")
 
82
  return out
83
 
84
 
85
+
86
+ iface = gr.Interface(fn=greet, inputs='text', output="image")
87
+
88
+ # image = gr.inputs.Image(label="Upload a photo", shape=(32,32))
89
+ # iface = gr.Interface(fn=greet, inputs=image, outputs="text")
90
  iface.launch()