Ziqi committed on
Commit 70ff35f · 1 Parent(s): acc4e78
Files changed (3):
  1. app.py +199 -14
  2. app_001.py +199 -0
  3. inference.py +81 -0
app.py CHANGED
@@ -54,10 +54,8 @@ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
 </center>
 '''
 
-os.system("git clone https://github.com/ziqihuangg/ReVersion")
-sys.path.append("ReVersion")
-
-from ReVersion.inference import *
+os.system("git clone https://github.com/adobe-research/custom-diffusion")
+sys.path.append("custom-diffusion")
 
 def show_warning(warning_text: str) -> gr.Blocks:
     with gr.Blocks() as demo:
@@ -72,6 +70,172 @@ def update_output_files() -> dict:
     return gr.update(value=paths or None)
 
 
+def create_training_demo(trainer: Trainer,
+                         pipe: InferencePipeline) -> gr.Blocks:
+    with gr.Blocks() as demo:
+        base_model = gr.Dropdown(
+            choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
+            value='CompVis/stable-diffusion-v1-4',
+            label='Base Model',
+            visible=True)
+        resolution = gr.Dropdown(choices=['512', '768'],
+                                 value='512',
+                                 label='Resolution',
+                                 visible=True)
+
+        with gr.Row():
+            with gr.Box():
+                concept_images_collection = []
+                concept_prompt_collection = []
+                class_prompt_collection = []
+                buttons_collection = []
+                delete_collection = []
+                is_visible = []
+                maximum_concepts = 3
+                row = [None] * maximum_concepts
+                for x in range(maximum_concepts):
+                    ordinal = lambda n: "%d%s" % (n, "tsnrhtdd"[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4])
+                    ordinal_concept = ["<new1> cat", "<new2> wooden pot", "<new3> chair"]
+                    if(x == 0):
+                        visible = True
+                        is_visible.append(gr.State(value=True))
+                    else:
+                        visible = False
+                        is_visible.append(gr.State(value=False))
+
+                    concept_images_collection.append(gr.Files(label=f'''Upload the images for your {ordinal(x+1) if (x>0) else ""} concept''', visible=visible))
+                    with gr.Column(visible=visible) as row[x]:
+                        concept_prompt_collection.append(
+                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} concept prompt ''', max_lines=1,
+                                       placeholder=f'''Example: "photo of a {ordinal_concept[x]}"''')
+                        )
+                        class_prompt_collection.append(
+                            gr.Textbox(label=f'''{ordinal(x+1) if (x>0) else ""} class prompt ''',
+                                       max_lines=1, placeholder=f'''Example: "{ordinal_concept[x][7:]}"''')
+                        )
+                        with gr.Row():
+                            if(x < maximum_concepts-1):
+                                buttons_collection.append(gr.Button(value=f"Add {ordinal(x+2)} concept", visible=visible))
+                            if(x > 0):
+                                delete_collection.append(gr.Button(value=f"Delete {ordinal(x+1)} concept"))
+
+                counter_add = 1
+                for button in buttons_collection:
+                    if(counter_add < len(buttons_collection)):
+                        button.click(lambda:
+                                     [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), True, None],
+                                     None,
+                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], buttons_collection[counter_add], is_visible[counter_add], concept_images_collection[counter_add]], queue=False)
+                    else:
+                        button.click(lambda:
+                                     [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), True],
+                                     None,
+                                     [row[counter_add], concept_images_collection[counter_add], buttons_collection[counter_add-1], is_visible[counter_add]], queue=False)
+                    counter_add += 1
+
+                counter_delete = 1
+                for delete_button in delete_collection:
+                    if(counter_delete < len(delete_collection)+1):
+                        if counter_delete == 1:
+                            delete_button.click(lambda:
+                                                [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), False],
+                                                None,
+                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], buttons_collection[counter_delete], is_visible[counter_delete]], queue=False)
+                        else:
+                            delete_button.click(lambda:
+                                                [gr.update(visible=False, value=None), gr.update(visible=False), gr.update(visible=True), False],
+                                                None,
+                                                [concept_images_collection[counter_delete], row[counter_delete], buttons_collection[counter_delete-1], is_visible[counter_delete]], queue=False)
+                    counter_delete += 1
+                gr.Markdown('''
+                - We use the "\<new1\>" modifier_token in front of the concept, e.g., "\<new1\> cat". For multiple concepts use "\<new2\>", "\<new3\>", etc. Increase the number of steps when training more concepts.
+                - For a new object concept, an example concept prompt is "photo of a \<new1\> cat", with "cat" as the class prompt.
+                - For a style concept, use "painting in the style of \<new1\> art" as the concept prompt and "art" as the class prompt.
+                - The class prompt should be the object category.
+                - If "Train Text Encoder" is checked, disable "modifier token" and use any unique text to describe the concept, e.g. "ktn cat".
+                ''')
+            with gr.Box():
+                gr.Markdown('Training Parameters')
+                with gr.Row():
+                    modifier_token = gr.Checkbox(label='modifier token',
+                                                 value=True)
+                    train_text_encoder = gr.Checkbox(label='Train Text Encoder',
+                                                     value=False)
+                num_training_steps = gr.Number(
+                    label='Number of Training Steps', value=1000, precision=0)
+                learning_rate = gr.Number(label='Learning Rate', value=0.00001)
+                batch_size = gr.Number(
+                    label='batch_size', value=1, precision=0)
+                with gr.Row():
+                    use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
+                    gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
+                with gr.Accordion('Other Parameters', open=False):
+                    gradient_accumulation = gr.Number(
+                        label='Number of Gradient Accumulation',
+                        value=1,
+                        precision=0)
+                    num_reg_images = gr.Number(
+                        label='Number of Class Concept images',
+                        value=200,
+                        precision=0)
+                    gen_images = gr.Checkbox(label='Generated images as regularization',
+                                             value=False)
+                gr.Markdown('''
+                - Training for 1000 steps takes about 10 minutes and ~21GB of memory on a 3090 GPU.
+                - Our results in the paper are trained with batch size 4 (8 including class regularization samples).
+                - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of a slower backward pass.
+                - Note that your trained models will be deleted when a second training run is started. You can upload your trained model in the "Upload" tab.
+                - We retrieve real images for the class concept using the clip_retrieval library, which can take some time.
+                ''')
+
+        run_button = gr.Button('Start Training')
+        with gr.Box():
+            with gr.Row():
+                check_status_button = gr.Button('Check Training Status')
+                with gr.Column():
+                    with gr.Box():
+                        gr.Markdown('Message')
+                        training_status = gr.Markdown()
+                    output_files = gr.Files(label='Trained Weight Files')
+
+        run_button.click(fn=pipe.clear,
+                         inputs=None,
+                         outputs=None)
+        run_button.click(fn=trainer.run,
+                         inputs=[
+                             base_model,
+                             resolution,
+                             num_training_steps,
+                             learning_rate,
+                             train_text_encoder,
+                             modifier_token,
+                             gradient_accumulation,
+                             batch_size,
+                             use_8bit_adam,
+                             gradient_checkpointing,
+                             gen_images,
+                             num_reg_images,
+                         ] +
+                         concept_images_collection +
+                         concept_prompt_collection +
+                         class_prompt_collection,
+                         outputs=[
+                             training_status,
+                             output_files,
+                         ],
+                         queue=False)
+        check_status_button.click(fn=trainer.check_if_running,
+                                  inputs=None,
+                                  outputs=training_status,
+                                  queue=False)
+        check_status_button.click(fn=update_output_files,
+                                  inputs=None,
+                                  outputs=output_files,
+                                  queue=False)
+    return demo
+
+
 def find_weight_files() -> list[str]:
     curr_dir = pathlib.Path(__file__).parent
     paths = sorted(curr_dir.rglob('*.bin'))
@@ -88,8 +252,8 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
         with gr.Row():
             with gr.Column():
                 base_model = gr.Dropdown(
-                    choices=['ReVersion/experiments/painted_on'],
-                    value='ReVersion/experiments/painted_on',
+                    choices=['stabilityai/stable-diffusion-2-1-base', 'CompVis/stable-diffusion-v1-4'],
+                    value='CompVis/stable-diffusion-v1-4',
                     label='Base Model',
                     visible=True)
                 resolution = gr.Dropdown(choices=[512, 768],
@@ -98,12 +262,12 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                                          visible=True)
                 reload_button = gr.Button('Reload Weight List')
                 weight_name = gr.Dropdown(choices=find_weight_files(),
-                                          value='ReVersion/experiments/painted_on',
-                                          label='ReVersion/experiments/painted_on')
+                                          value='custom-diffusion-models/cat.bin',
+                                          label='Custom Diffusion Weight File')
                 prompt = gr.Textbox(
                     label='Prompt',
                     max_lines=1,
-                    placeholder='Example: "cat <R> stone"')
+                    placeholder='Example: "\<new1\> cat in outer space"')
                 seed = gr.Slider(label='Seed',
                                  minimum=0,
                                  maximum=100000,
@@ -175,6 +339,27 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
     return demo
 
 
+def create_upload_demo() -> gr.Blocks:
+    with gr.Blocks() as demo:
+        model_name = gr.Textbox(label='Model Name')
+        hf_token = gr.Textbox(
+            label='Hugging Face Token (with write permission)')
+        upload_button = gr.Button('Upload')
+        with gr.Box():
+            gr.Markdown('Message')
+            result = gr.Markdown()
+        gr.Markdown('''
+        - You can upload your trained model to your private Model repo (e.g. https://huggingface.co/{your_username}/{model_name}).
+        - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
+        ''')
+
+        upload_button.click(fn=upload,
+                            inputs=[model_name, hf_token],
+                            outputs=result)
+
+    return demo
+
+
 pipe = InferencePipeline()
 trainer = Trainer()
 
@@ -189,12 +374,12 @@ with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DETAILDESCRIPTION)
 
     with gr.Tabs():
-        # with gr.TabItem('Train'):
-        #     create_training_demo(trainer, pipe)
-        with gr.TabItem('Inference'):
+        with gr.TabItem('Train'):
+            create_training_demo(trainer, pipe)
+        with gr.TabItem('Test'):
             create_inference_demo(pipe)
-        # with gr.TabItem('Upload'):
-        #     create_upload_demo()
+        with gr.TabItem('Upload'):
+            create_upload_demo()
 
 demo.queue(default_enabled=False).launch(share=False)
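A note on the densest line in the new create_training_demo: the ordinal helper. The snippet below is a reading aid, not part of the commit. The string "tsnrhtdd" interleaves the four English suffixes, so slicing with step 4 from offset 0 gives "th", from 1 gives "st", from 2 gives "nd", and from 3 gives "rd".

    def ordinal(n: int) -> str:
        # Use "st"/"nd"/"rd" only when the last digit is 1-3 and n is not in the teens;
        # otherwise the index expression evaluates to 0, selecting "th".
        idx = (n // 10 % 10 != 1) * (n % 10 < 4) * n % 10
        return "%d%s" % (n, "tsnrhtdd"[idx::4])

    assert [ordinal(n) for n in (1, 2, 3, 4, 11, 21)] == ['1st', '2nd', '3rd', '4th', '11th', '21st']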
app_001.py ADDED
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+"""Demo app for https://github.com/adobe-research/custom-diffusion.
+The code in this repo is partly adapted from the following repository:
+https://huggingface.co/spaces/hysts/LoRA-SD-training
+MIT License
+Copyright (c) 2022 hysts
+==========================================================================================
+Adobe’s modifications are Copyright 2022 Adobe Research. All rights reserved.
+Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
+LICENSE.
+==========================================================================================
+"""
+
+from __future__ import annotations
+import sys
+import os
+import pathlib
+
+import gradio as gr
+import torch
+
+from inference import InferencePipeline
+from trainer import Trainer
+from uploader import upload
+
+TITLE = '# Custom Diffusion + StableDiffusion Training UI'
+DESCRIPTION = '''This is a demo for [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion).
+It is recommended to upgrade to a GPU in Settings after duplicating this Space to use it.
+<a href="https://huggingface.co/spaces/nupurkmr9/custom-diffusion?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+'''
+DETAILDESCRIPTION='''
+Custom Diffusion allows you to fine-tune text-to-image diffusion models, such as Stable Diffusion, given a few images of a new concept (~4-20).
+We fine-tune only a subset of model parameters, namely the key and value projection matrices in the cross-attention layers, and the modifier token used to represent the object.
+This also reduces the extra storage for each additional concept to 75MB. Our method also allows you to use a combination of concepts. There are still limitations on which compositions work. For more analysis please refer to our [website](https://www.cs.cmu.edu/~custom-diffusion/).
+<center>
+<img src="https://huggingface.co/spaces/nupurkmr9/custom-diffusion/resolve/main/method.jpg" width="600" align="center" >
+</center>
+'''
+
+ORIGINAL_SPACE_ID = 'nupurkmr9/custom-diffusion'
+SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
+SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
+<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r059973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
+'''
+if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
+    SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
+
+else:
+    SETTINGS = 'Settings'
+CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
+<center>
+You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
+"T4 small" is sufficient to run this demo.
+</center>
+'''
+
+os.system("git clone https://github.com/ziqihuangg/ReVersion")
+sys.path.append("ReVersion")
+
+
+def show_warning(warning_text: str) -> gr.Blocks:
+    with gr.Blocks() as demo:
+        with gr.Box():
+            gr.Markdown(warning_text)
+    return demo
+
+
+def update_output_files() -> dict:
+    paths = sorted(pathlib.Path('results').glob('*.bin'))
+    paths = [path.as_posix() for path in paths]  # type: ignore
+    return gr.update(value=paths or None)
+
+
+def find_weight_files() -> list[str]:
+    curr_dir = pathlib.Path(__file__).parent
+    paths = sorted(curr_dir.rglob('*.bin'))
+    paths = [path for path in paths if '.lfs' not in str(path)]
+    return [path.relative_to(curr_dir).as_posix() for path in paths]
+
+
+def reload_custom_diffusion_weight_list() -> dict:
+    return gr.update(choices=find_weight_files())
+
+
+def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
+    with gr.Blocks() as demo:
+        with gr.Row():
+            with gr.Column():
+                base_model = gr.Dropdown(
+                    choices=['ReVersion/experiments/painted_on'],
+                    value='ReVersion/experiments/painted_on',
+                    label='Base Model',
+                    visible=True)
+                resolution = gr.Dropdown(choices=[512, 768],
+                                         value=512,
+                                         label='Resolution',
+                                         visible=True)
+                reload_button = gr.Button('Reload Weight List')
+                weight_name = gr.Dropdown(choices=find_weight_files(),
+                                          value='ReVersion/experiments/painted_on',
+                                          label='ReVersion/experiments/painted_on')
+                prompt = gr.Textbox(
+                    label='Prompt',
+                    max_lines=1,
+                    placeholder='Example: "cat <R> stone"')
+                seed = gr.Slider(label='Seed',
+                                 minimum=0,
+                                 maximum=100000,
+                                 step=1,
+                                 value=42)
+                with gr.Accordion('Other Parameters', open=False):
+                    num_steps = gr.Slider(label='Number of Steps',
+                                          minimum=0,
+                                          maximum=500,
+                                          step=1,
+                                          value=100)
+                    guidance_scale = gr.Slider(label='CFG Scale',
+                                               minimum=0,
+                                               maximum=50,
+                                               step=0.1,
+                                               value=6)
+                    eta = gr.Slider(label='DDIM eta',
+                                    minimum=0,
+                                    maximum=1.,
+                                    step=0.1,
+                                    value=1.)
+                    batch_size = gr.Slider(label='Batch Size',
+                                           minimum=0,
+                                           maximum=10.,
+                                           step=1,
+                                           value=1)
+
+                run_button = gr.Button('Generate')
+
+                gr.Markdown('''
+                - Models with names starting with "custom-diffusion-models/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones under "results/" (e.g. "results/delta.bin") are your trained models.
+                - After training, you can press the "Reload Weight List" button to load your trained model names.
+                - Increase the number of steps under Other Parameters for qualitatively better samples.
+                ''')
+            with gr.Column():
+                result = gr.Image(label='Result')
+
+        reload_button.click(fn=reload_custom_diffusion_weight_list,
+                            inputs=None,
+                            outputs=weight_name)
+        prompt.submit(fn=pipe.run,
+                      inputs=[
+                          base_model,
+                          weight_name,
+                          prompt,
+                          seed,
+                          num_steps,
+                          guidance_scale,
+                          eta,
+                          batch_size,
+                          resolution
+                      ],
+                      outputs=result,
+                      queue=False)
+        run_button.click(fn=pipe.run,
+                         inputs=[
+                             base_model,
+                             weight_name,
+                             prompt,
+                             seed,
+                             num_steps,
+                             guidance_scale,
+                             eta,
+                             batch_size,
+                             resolution
+                         ],
+                         outputs=result,
+                         queue=False)
+    return demo
+
+
+pipe = InferencePipeline()
+trainer = Trainer()
+
+with gr.Blocks(css='style.css') as demo:
+    if os.getenv('IS_SHARED_UI'):
+        show_warning(SHARED_UI_WARNING)
+    if not torch.cuda.is_available():
+        show_warning(CUDA_NOT_AVAILABLE_WARNING)
+
+    gr.Markdown(TITLE)
+    gr.Markdown(DESCRIPTION)
+    gr.Markdown(DETAILDESCRIPTION)
+
+    with gr.Tabs():
+        # with gr.TabItem('Train'):
+        #     create_training_demo(trainer, pipe)
+        with gr.TabItem('Inference'):
+            create_inference_demo(pipe)
+        # with gr.TabItem('Upload'):
+        #     create_upload_demo()
+
+demo.queue(default_enabled=False).launch(share=False)
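app_001.py preserves the previous ReVersion-based app verbatim. Its "Reload Weight List" wiring uses the standard Gradio pattern of returning gr.update(...) from a callback to patch a component in place. A minimal self-contained sketch of just that pattern, assuming the same Gradio 3.x release this Space uses:

    import pathlib
    import gradio as gr

    def find_weight_files() -> list[str]:
        # Same scan the app performs: every .bin file under the app directory.
        curr_dir = pathlib.Path(__file__).parent
        return [p.relative_to(curr_dir).as_posix() for p in sorted(curr_dir.rglob('*.bin'))]

    with gr.Blocks() as demo:
        weight_name = gr.Dropdown(choices=find_weight_files(), label='Weight File')
        reload_button = gr.Button('Reload Weight List')
        # Returning gr.update(choices=...) refreshes the dropdown without rebuilding the UI.
        reload_button.click(fn=lambda: gr.update(choices=find_weight_files()),
                            inputs=None,
                            outputs=weight_name)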
inference.py ADDED
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import gc
+import pathlib
+import sys
+
+import gradio as gr
+import PIL.Image
+import numpy as np
+
+import torch
+from diffusers import StableDiffusionPipeline
+sys.path.insert(0, './ReVersion')
+
+
+class InferencePipeline:
+    def __init__(self):
+        self.pipe = None
+        self.device = torch.device(
+            'cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.weight_path = None
+
+    def clear(self) -> None:
+        self.weight_path = None
+        del self.pipe
+        self.pipe = None
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    @staticmethod
+    def get_weight_path(name: str) -> pathlib.Path:
+        curr_dir = pathlib.Path(__file__).parent
+        return curr_dir / name
+
+    def load_pipe(self, model_id: str, filename: str) -> None:
+        weight_path = self.get_weight_path(filename)
+        if weight_path == self.weight_path:
+            return
+        self.weight_path = weight_path
+        weight = torch.load(self.weight_path, map_location=self.device)
+
+        if self.device.type == 'cpu':
+            pipe = StableDiffusionPipeline.from_pretrained(model_id)
+        else:
+            pipe = StableDiffusionPipeline.from_pretrained(
+                model_id, torch_dtype=torch.float16)
+        pipe = pipe.to(self.device)
+
+        from src import diffuser_training
+        diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, compress=False)
+
+        self.pipe = pipe
+
+    def run(
+        self,
+        base_model: str,
+        weight_name: str,
+        prompt: str,
+        seed: int,
+        n_steps: int,
+        guidance_scale: float,
+        eta: float,
+        batch_size: int,
+        resolution: int,
+    ) -> PIL.Image.Image:
+        if not torch.cuda.is_available():
+            raise gr.Error('CUDA is not available.')
+
+        self.load_pipe(base_model, weight_name)
+
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        out = self.pipe([prompt] * batch_size,
+                        num_inference_steps=n_steps,
+                        guidance_scale=guidance_scale,
+                        height=resolution, width=resolution,
+                        eta=eta,
+                        generator=generator)  # type: ignore
+        torch.cuda.empty_cache()
+        out = out.images
+        out = PIL.Image.fromarray(np.hstack([np.array(x) for x in out]))
+        return out
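The new InferencePipeline defers model loading to the first run() call and reloads only when the selected delta file changes; note that `from src import diffuser_training` appears to resolve via the custom-diffusion checkout that app.py clones, even though inference.py itself inserts './ReVersion' on sys.path. A usage sketch under stated assumptions: the weight path and prompt are illustrative placeholders, and a CUDA GPU is required since run() raises gr.Error on CPU.

    # Illustrative call into inference.py's InferencePipeline (not part of the commit).
    from inference import InferencePipeline

    pipe = InferencePipeline()
    image = pipe.run(
        base_model='CompVis/stable-diffusion-v1-4',     # any Diffusers SD checkpoint id
        weight_name='custom-diffusion-models/cat.bin',  # delta weights, relative to the app dir
        prompt='<new1> cat in outer space',
        seed=42,
        n_steps=100,
        guidance_scale=6.0,
        eta=1.0,
        batch_size=1,
        resolution=512,
    )
    image.save('sample.png')  # run() returns one PIL.Image with the batch hstacked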