Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -27,18 +27,18 @@ def get_gpt2_pipeline():
|
|
27 |
# SD v1.4
|
28 |
def get_stable_diffusion_v14_pipeline():
|
29 |
model_id = "CompVis/stable-diffusion-v1-4"
|
30 |
-
|
31 |
# pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
|
32 |
-
|
33 |
torch.backends.cudnn.benchmark = True
|
34 |
-
return
|
35 |
|
36 |
# SD v1.5
|
37 |
def get_stable_diffusion_v15_pipeline():
|
38 |
model_id = "runwayml/stable-diffusion-v1-5"
|
39 |
-
|
40 |
-
|
41 |
-
return
|
42 |
|
43 |
def get_image(url):
|
44 |
response = requests.get(url)
|
@@ -50,9 +50,9 @@ def get_image(url):
|
|
50 |
def main():
|
51 |
prompt = "Hello world, I'm vizard,"
|
52 |
|
53 |
-
|
54 |
def greet(prompt):
|
55 |
-
return
|
56 |
|
57 |
ui = gr.Interface(
|
58 |
fn=greet,
|
@@ -61,7 +61,7 @@ def main():
|
|
61 |
)
|
62 |
ui.launch()
|
63 |
|
64 |
-
|
65 |
-
images =
|
66 |
|
67 |
main
|
|
|
27 |
# SD v1.4
|
28 |
def get_stable_diffusion_v14_pipeline():
|
29 |
model_id = "CompVis/stable-diffusion-v1-4"
|
30 |
+
pipe = StableDiffusionPipeline.from_pretrained(mode_id)
|
31 |
# pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
|
32 |
+
pipe = pipe.to(DEVICE)
|
33 |
torch.backends.cudnn.benchmark = True
|
34 |
+
return pipe
|
35 |
|
36 |
# SD v1.5
|
37 |
def get_stable_diffusion_v15_pipeline():
|
38 |
model_id = "runwayml/stable-diffusion-v1-5"
|
39 |
+
pipe = DiffusionPipeline.from_pretrained(mode_id)
|
40 |
+
pipe = pipe.to(DEVICE)
|
41 |
+
return pipe
|
42 |
|
43 |
def get_image(url):
|
44 |
response = requests.get(url)
|
|
|
50 |
def main():
|
51 |
prompt = "Hello world, I'm vizard,"
|
52 |
|
53 |
+
pipe = get_gpt2_pipeline()
|
54 |
def greet(prompt):
|
55 |
+
return pipe(prompt, max_length=50, num_return_sequences=3)
|
56 |
|
57 |
ui = gr.Interface(
|
58 |
fn=greet,
|
|
|
61 |
)
|
62 |
ui.launch()
|
63 |
|
64 |
+
pipe2 = get_stable_diffusion_v15_pipeline()
|
65 |
+
images = pipe2(prompt).images
|
66 |
|
67 |
main
|