vs4vijay commited on
Commit
460180a
·
1 Parent(s): a0bf13b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -27,18 +27,18 @@ def get_gpt2_pipeline():
27
  # SD v1.4
28
  def get_stable_diffusion_v14_pipeline():
29
  model_id = "CompVis/stable-diffusion-v1-4"
30
- pipeline = StableDiffusionPipeline.from_pretrained(mode_id)
31
  # pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
32
- pipeline = pipeline.to(DEVICE)
33
  torch.backends.cudnn.benchmark = True
34
- return pipeline
35
 
36
  # SD v1.5
37
  def get_stable_diffusion_v15_pipeline():
38
  model_id = "runwayml/stable-diffusion-v1-5"
39
- pipeline = DiffusionPipeline.from_pretrained(mode_id)
40
- pipeline = pipeline.to(DEVICE)
41
- return pipeline
42
 
43
  def get_image(url):
44
  response = requests.get(url)
@@ -50,9 +50,9 @@ def get_image(url):
50
  def main():
51
  prompt = "Hello world, I'm vizard,"
52
 
53
- pipeline = get_gpt2_pipeline()
54
  def greet(prompt):
55
- return pipeline(prompt, max_length=50, num_return_sequences=3)
56
 
57
  ui = gr.Interface(
58
  fn=greet,
@@ -61,7 +61,7 @@ def main():
61
  )
62
  ui.launch()
63
 
64
- pipeline = get_stable_diffusion_v15_pipeline()
65
- images = pipeline(prompt).images
66
 
67
  main
 
27
  # SD v1.4
28
  def get_stable_diffusion_v14_pipeline():
29
  model_id = "CompVis/stable-diffusion-v1-4"
30
+ pipe = StableDiffusionPipeline.from_pretrained(mode_id)
31
  # pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
32
+ pipe = pipe.to(DEVICE)
33
  torch.backends.cudnn.benchmark = True
34
+ return pipe
35
 
36
  # SD v1.5
37
  def get_stable_diffusion_v15_pipeline():
38
  model_id = "runwayml/stable-diffusion-v1-5"
39
+ pipe = DiffusionPipeline.from_pretrained(mode_id)
40
+ pipe = pipe.to(DEVICE)
41
+ return pipe
42
 
43
  def get_image(url):
44
  response = requests.get(url)
 
50
  def main():
51
  prompt = "Hello world, I'm vizard,"
52
 
53
+ pipe = get_gpt2_pipeline()
54
  def greet(prompt):
55
+ return pipe(prompt, max_length=50, num_return_sequences=3)
56
 
57
  ui = gr.Interface(
58
  fn=greet,
 
61
  )
62
  ui.launch()
63
 
64
+ pipe2 = get_stable_diffusion_v15_pipeline()
65
+ images = pipe2(prompt).images
66
 
67
  main