cbensimon HF Staff commited on
Commit
24ec8bb
·
1 Parent(s): ce8b907
Files changed (1) hide show
  1. app.py +8 -13
app.py CHANGED
@@ -1,15 +1,12 @@
1
  """
2
  """
3
- from datetime import datetime
4
- t0 = datetime.now()
5
-
6
  # Upgrade PyTorch
7
  import os
8
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
9
- print('torch upgrade', -(t0 - (t0 := datetime.now())))
10
 
11
  # Actual app.py
12
  import os
 
13
 
14
  import gradio as gr
15
  import spaces
@@ -22,7 +19,6 @@ from zerogpu import aoti_compile
22
 
23
 
24
  pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
25
- print('FluxPipeline.from_pretrained', -(t0 - (t0 := datetime.now())))
26
 
27
 
28
  @spaces.GPU(duration=1500)
@@ -70,17 +66,16 @@ transformer_config = pipeline.transformer.config
70
  pipeline.transformer = compile_transformer()
71
  pipeline.transformer.config = transformer_config
72
 
 
73
  @spaces.GPU
74
- def _generate_image(prompt: str, t0: datetime):
75
- print('@spaces.GPU', -(t0 - (t0 := datetime.now())))
76
  images = []
77
  for _ in range(4):
78
- images += pipeline(prompt, num_inference_steps=4).images
79
- print('pipeline', -(t0 - (t0 := datetime.now())))
 
80
  return images
81
 
82
 
83
- def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
84
- return _generate_image(prompt, datetime.now())
85
-
86
- gr.Interface(generate_image, gr.Text(), gr.Gallery()).launch(show_error=True)
 
1
  """
2
  """
 
 
 
3
  # Upgrade PyTorch
4
  import os
5
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
 
6
 
7
  # Actual app.py
8
  import os
9
+ from datetime import datetime
10
 
11
  import gradio as gr
12
  import spaces
 
19
 
20
 
21
  pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16).to('cuda')
 
22
 
23
 
24
  @spaces.GPU(duration=1500)
 
66
  pipeline.transformer = compile_transformer()
67
  pipeline.transformer.config = transformer_config
68
 
69
+
70
  @spaces.GPU
71
+ def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
72
+ t0 = datetime.now()
73
  images = []
74
  for _ in range(4):
75
+ image = pipeline(prompt, num_inference_steps=4).images[0]
76
+ elapsed = -(t0 - (t0 := datetime.now()))
77
+ images += [(image, f'{elapsed.total_seconds():.2f}s')]
78
  return images
79
 
80
 
81
+ gr.Interface(generate_image, gr.Text(), gr.Gallery(rows=3, columns=3, height='60vh')).launch()