sWizad sincostanx commited on
Commit
3d91cf3
·
0 Parent(s):

Duplicate from sincostanx/momentum-diffusion

Browse files

Co-authored-by: Worameth Chinchuthakun <[email protected]>

Files changed (6) hide show
  1. .gitattributes +34 -0
  2. README.md +15 -0
  3. app.py +227 -0
  4. momentum_scheduler.py +385 -0
  5. pipeline.py +236 -0
  6. requirements.txt +96 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Momentum Diffusion
3
+ emoji: 🐠
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.36.1
8
+ python_version: 3.9.17
9
+ app_file: app.py
10
+ pinned: false
11
+ license: cc-by-4.0
12
+ duplicated_from: sincostanx/momentum-diffusion
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from pipeline import CustomPipeline, setup_scheduler
4
+ from diffusers import StableDiffusionPipeline
5
+ from PIL import Image
6
+ # from easydict import EasyDict as edict
7
+
8
+ original_pipe = None
9
+ original_config = None
10
+ device = None
11
+
12
+
13
+ # def run_dpm_demo(id, prompt, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
14
def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
    """Render the same prompt twice with DPM-Solver++: plain, then with momentum.

    Reads the module-level ``original_pipe``, ``original_config`` and ``device``
    that are populated when a model is selected in the UI.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        beta: Momentum coefficient used for the second render.
        num_inference_steps: Number of sampling steps (coerced to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Seed (coerced to ``int``) used for both renders.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    solver_name = "DPM-Solver++"
    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "dpm",
    }

    # First pass uses beta=1.0 (no momentum, i.e. plain DPM-Solver++); the
    # second pass uses the requested beta. A fresh generator with the same
    # seed is created for each pass.
    renders = []
    for momentum in (1.0, beta):
        pipe = setup_scheduler(pipe, solver_name, beta=momentum, original_config=original_config)
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
        renders.append(pipe(**call_kwargs).images[0])

    baseline, with_momentum = renders

    # NOTE(review): both images are also written to the working directory;
    # the sibling demos do not do this -- confirm whether it is still needed.
    baseline.save("temp1.png")
    with_momentum.save("temp2.png")

    return [baseline, with_momentum]
43
+
44
+ # def run_plms_demo(id, prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed, enable_token_merging):
45
def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
    """Render the same prompt twice with PLMS: plain, then with momentum.

    Reads the module-level ``original_pipe``, ``original_config`` and ``device``
    that are populated when a model is selected in the UI.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        order: Solver order forwarded to ``setup_scheduler``.
        beta: Momentum coefficient used for the second render.
        momentum_type: ``"Polyak's heavy ball"`` selects the ``"hb"`` stepping
            method; any other value selects ``"nt"``.
        num_inference_steps: Number of sampling steps (coerced to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Seed (coerced to ``int``) used for both renders.

    Returns:
        ``[image_without_momentum, image_with_momentum]``.
    """
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    solver_name = "PLMS"
    if momentum_type == "Polyak's heavy ball":
        step_method = "hb"
    else:
        step_method = "nt"

    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": step_method,
    }

    # beta=1.0 is the momentum-free reference (equivalent to PLMS); the second
    # pass uses the requested beta with an identically seeded generator.
    renders = []
    for momentum in (1.0, beta):
        pipe = setup_scheduler(
            pipe,
            solver_name,
            momentum_type=momentum_type,
            order=order,
            beta=momentum,
            original_config=original_config,
        )
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
        renders.append(pipe(**call_kwargs).images[0])

    baseline, with_momentum = renders
    return [baseline, with_momentum]
72
+
73
+ # def run_ghvb_demo(id, prompt, order, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
74
def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
    """Render the same prompt twice with GHVB: beta=1.0 first, then the given beta.

    Reads the module-level ``original_pipe``, ``original_config`` and ``device``
    that are populated when a model is selected in the UI.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        order: Solver order forwarded to ``setup_scheduler``.
        beta: Momentum coefficient used for the second render.
        num_inference_steps: Number of sampling steps (coerced to ``int``).
        guidance_scale: Classifier-free guidance scale.
        seed: Seed (coerced to ``int``) used for both renders.

    Returns:
        ``[reference_image, ghvb_image]``.
    """
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    solver_name = "GHVB"
    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "ghvb",
    }

    # beta=1.0 is the momentum-free reference (equivalent to PLMS); the second
    # pass uses the requested beta with an identically seeded generator.
    renders = []
    for momentum in (1.0, beta):
        pipe = setup_scheduler(pipe, solver_name, order=order, beta=momentum, original_config=original_config)
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
        renders.append(pipe(**call_kwargs).images[0])

    baseline, with_momentum = renders
    return [baseline, with_momentum]
100
+
101
if __name__ == "__main__":
    # Per-tab registries: input components, output components and the
    # "Sample" button, keyed by solver name.
    inputs = {}
    outputs = {}
    buttons = {}

    list_models = [
        "Linaqruf/anything-v3.0",
        "runwayml/stable-diffusion-v1-5",
        "dreamlike-art/dreamlike-photoreal-2.0",
    ]
    # Load every listed model once at start-up and drop it again right away;
    # presumably this only warms the local download cache so that selecting a
    # model later is fast.
    for model_id in list_models:
        pipeline = StableDiffusionPipeline.from_pretrained(model_id)
        del pipeline
        print(f"Downloaded {model_id}")

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Momentum-Diffusion Demo

            A novel sampling method for diffusion models based on momentum to reduce artifacts

            """
        )
        # Named `model_dropdown` rather than `id` to avoid shadowing the builtin.
        model_dropdown = gr.Dropdown(
            list_models, label="Model ID", value="Linaqruf/anything-v3.0", allow_custom_value=True
        )
        enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
        # output = gr.Textbox()
        buttons["select_model"] = gr.Button("Select")

        # All sampling controls start hidden and are revealed by
        # `prepare_model` once a model has been loaded.
        with gr.Tab("GHVB", visible=False) as tab3:
            prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)

            with gr.Row(visible=False) as row31:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=12)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)

            with gr.Row(visible=False) as row32:
                out1 = gr.Image(label="PLMS", interactive=False)
                out2 = gr.Image(label="GHVB", interactive=False)

            inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
            outputs["GHVB"] = [out1, out2]
            buttons["GHVB"] = gr.Button("Sample", visible=False)

        with gr.Tab("PLMS", visible=False) as tab2:
            prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)

            with gr.Row(visible=False) as row21:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
                momentum_type = gr.Dropdown(
                    ["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball"
                )
                num_inference_steps = gr.Number(label="Number of steps", value=10)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)

            with gr.Row(visible=False) as row22:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)

            inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
            outputs["PLMS"] = [out1, out2]
            buttons["PLMS"] = gr.Button("Sample", visible=False)

        with gr.Tab("DPM-Solver++", visible=False) as tab1:
            prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)

            with gr.Row(visible=False) as row11:
                beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=15)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
                seed = gr.Number(label="Seed", value=0)

            with gr.Row(visible=False) as row12:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)

            inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
            outputs["DPM-Solver++"] = [out1, out2]
            buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)

        def prepare_model(model_id, enable_token_merging):
            """Load the selected model and reveal the sampling controls.

            Stores the loaded pipeline, its scheduler config and the chosen
            device in the module-level ``original_pipe``, ``original_config``
            and ``device`` read by the ``run_*_demo`` callbacks.

            Args:
                model_id: Model identifier passed to ``CustomPipeline.from_pretrained``.
                enable_token_merging: When truthy, patch the pipeline with
                    ``tomesd`` at ratio 0.5.

            Returns:
                A mapping from each hidden component to a visibility update.
            """
            global original_pipe, original_config, device

            # Release the previous pipeline by rebinding to None. Using `del`
            # on the global would leave the name undefined if the load below
            # raised, making every later click fail with NameError.
            original_pipe = None

            original_pipe = CustomPipeline.from_pretrained(model_id)

            # Prefer the second GPU as before, but fall back to the first one
            # on single-GPU hosts where "cuda:1" is not a valid device.
            if torch.cuda.is_available():
                gpu_index = 1 if torch.cuda.device_count() > 1 else 0
                device = torch.device(f"cuda:{gpu_index}")
            else:
                device = torch.device("cpu")
            original_pipe = original_pipe.to(device)

            if enable_token_merging:
                # Imported lazily so the dependency is only needed when the
                # option is actually used.
                import tomesd
                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token merging.")

            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)

            return {
                row11: gr.update(visible=True),
                row12: gr.update(visible=True),
                row21: gr.update(visible=True),
                row22: gr.update(visible=True),
                row31: gr.update(visible=True),
                row32: gr.update(visible=True),
                prompt1: gr.update(visible=True),
                prompt2: gr.update(visible=True),
                prompt3: gr.update(visible=True),
                buttons["DPM-Solver++"]: gr.update(visible=True),
                buttons["PLMS"]: gr.update(visible=True),
                buttons["GHVB"]: gr.update(visible=True),
            }

        all_outputs = [
            row11, row12, row21, row22, row31, row32,
            prompt1, prompt2, prompt3,
            buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"],
        ]
        buttons["select_model"].click(
            prepare_model, inputs=[model_dropdown, enable_token_merging], outputs=all_outputs
        )
        buttons["DPM-Solver++"].click(run_dpm_demo, inputs=inputs["DPM-Solver++"], outputs=outputs["DPM-Solver++"])
        buttons["PLMS"].click(run_plms_demo, inputs=inputs["PLMS"], outputs=outputs["PLMS"])
        buttons["GHVB"].click(run_ghvb_demo, inputs=inputs["GHVB"], outputs=outputs["GHVB"])

    demo.launch()
momentum_scheduler.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
3
+ from typing import List
4
+
5
def AdamBmixer(order, ets, b=1):
    """Blend the most recent entries of ``ets`` with order-dependent weights.

    The effective order is ``min(order, len(ets))``. Each weight is an affine
    function of ``b``; with ``b == 1`` the weights reduce to the classical
    Adams-Bashforth coefficients of that order. The weighted sum is divided
    by ``b`` before being returned.

    Args:
        order: Requested order (1 to 5).
        ets: History of values, newest last.
        b: Blending coefficient; must be non-zero because of the final division.

    Returns:
        The weighted combination of the latest ``min(order, len(ets))`` entries.

    Raises:
        NotImplementedError: If the effective order is not one of 1..5
            (including an empty history).
    """
    depth = min(order, len(ets))

    if depth == 1:
        return b * ets[-1] / b

    if depth == 2:
        e1, e2 = ets[-1], ets[-2]
        return ((2+b) * e1 - (2-b)*e2) / 2 / b

    if depth == 3:
        e1, e2, e3 = ets[-1], ets[-2], ets[-3]
        return ((18+5*b) * e1 - (24-8*b) * e2 + (6-1*b) * e3) / 12 / b

    if depth == 4:
        e1, e2, e3, e4 = ets[-1], ets[-2], ets[-3], ets[-4]
        return ((46+9*b) * e1 - (78-19*b) * e2 + (42-5*b) * e3 - (10-b) * e4) / 24 / b

    if depth == 5:
        e1, e2, e3, e4, e5 = ets[-1], ets[-2], ets[-3], ets[-4], ets[-5]
        return ((1650+251*b) * e1 - (3420-646*b) * e2
                + (2880-264*b) * e3 - (1380-106*b) * e4
                + (270-19*b) * e5) / 720 / b

    raise NotImplementedError
25
+
26
class PLMSWithHBScheduler():
    """
    PLMS sampler with Polyak's Heavy Ball (HB) momentum for diffusion ODEs.

    Wraps a scheduler from diffusers (https://github.com/huggingface/diffusers)
    and keeps its own gradient history (``ets``) and velocity (``vel``).

    The ``order`` argument encodes both the solver order and the momentum
    coefficient: a whole number ``k`` means order ``k`` with beta = 1 (plain
    PLMS, no momentum); a fractional value means order ``ceil(order)`` with
    beta equal to the fractional part.
    """
    def __init__(self, scheduler, order):
        self.scheduler = scheduler
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer

    def update_order(self, order):
        """Split ``order`` into an integral solver order and a beta, and reset the velocity."""
        fraction = order % 1
        whole = order // 1
        if fraction > 0:
            self.order = whole + 1
            self.beta = fraction
        else:
            self.order = whole
            self.beta = 1
        self.vel = None

    def clear(self):
        """Forget the gradient history and the velocity."""
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        """Append ``val`` to the history, keeping at most ``self.order`` entries."""
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        """Record ``grads``, mix the history, and fold the result into the velocity."""
        self.update_ets(grads)
        mixed = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * mixed
        return self.vel

    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        """Advance ``latents`` by one step from ``timestep``.

        Args:
            grads: Model output driving this update.
            timestep: Current timestep; must occur exactly once in
                ``self.scheduler.timesteps``.
            latents: Current latents.
            output_mode: Only used for ``DPMSolverMultistepScheduler``-based
                wrappers; selects how the update is scaled (``"scale"``,
                ``"back"``, ``"front"``, anything else applies no scaling).

        Returns:
            The updated latents.

        Raises:
            NotImplementedError: If the wrapped scheduler has no ``sigmas``
                attribute and is not a ``DPMSolverMultistepScheduler``.
        """
        # The first call seeds the velocity with the incoming gradient.
        if self.vel is None:
            self.vel = grads

        wrapped = self.scheduler

        # Schedulers exposing `sigmas`: step directly along the sigma schedule.
        if hasattr(wrapped, 'sigmas'):
            idx = (wrapped.timesteps == timestep).nonzero().item()
            del_g = wrapped.sigmas[idx + 1] - wrapped.sigmas[idx]
            return latents + del_g * self._step_with_momentum(grads)

        if not isinstance(wrapped, DPMSolverMultistepScheduler):
            raise NotImplementedError

        idx = (wrapped.timesteps == timestep).nonzero().item()
        current_timestep = wrapped.timesteps[idx]
        # After the last scheduled timestep the target is timestep 0.
        if idx == len(wrapped.timesteps) - 1:
            prev_timestep = 0
        else:
            prev_timestep = wrapped.timesteps[idx + 1]

        alpha_bar_now = wrapped.alphas_cumprod[current_timestep]
        alpha_bar_prev = wrapped.alphas_cumprod[prev_timestep]

        s0 = torch.sqrt(alpha_bar_now)
        s_1 = torch.sqrt(alpha_bar_prev)
        g0 = torch.sqrt(1-alpha_bar_now)/s0
        g_1 = torch.sqrt(1-alpha_bar_prev)/s_1
        del_g = g_1 - g0

        update_val = self._step_with_momentum(grads)
        if output_mode == "scale":
            return (latents/s0 + del_g * update_val) * s_1
        if output_mode == "back":
            return latents + del_g * update_val * s_1
        if output_mode == "front":
            return latents + del_g * update_val * s0
        return latents + del_g * update_val
102
+
103
class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalized Heavy Ball (GHVB) sampler for diffusion ODEs.

    Implemented as a wrapper around schedulers from diffusers
    (https://github.com/huggingface/diffusers). It differs from the parent
    class only in where the momentum is applied: the velocity is updated from
    the raw gradient first, and the velocity history (not the gradient
    history) is then mixed using beta as the blending coefficient.

    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        """Update the velocity from ``grads`` and return the mixed velocity history."""
        damped = (1 - self.beta) * self.vel
        self.vel = damped + self.beta * grads
        self.update_ets(self.vel)
        return self.mixer(self.order, self.ets, self.beta)
115
+
116
class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS sampler with Nesterov momentum (NT) for diffusion ODEs.

    Implemented as a wrapper around schedulers from diffusers
    (https://github.com/huggingface/diffusers). Compared with the heavy-ball
    parent, the returned update applies the momentum blend a second time on
    top of the freshly updated velocity.

    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        """Record ``grads``, refresh the velocity, and return the look-ahead update."""
        self.update_ets(grads)
        mixed = self.mixer(self.order, self.ets, 1.0)  # v^{(2)}
        keep = 1 - self.beta
        self.vel = keep * self.vel + self.beta * mixed  # v^{(1)}
        # Same blend again, now starting from the updated velocity.
        return keep * self.vel + self.beta * mixed
129
+
130
class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.
    Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".

    When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.

    ``initialize_momentum`` must be called before sampling so that ``vel`` and
    ``beta`` exist.
    """
    def initialize_momentum(self, beta):
        """Reset the velocity and set the momentum coefficient ``beta``."""
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Second-order multistep update with a heavy-ball velocity.

        Args:
            model_output_list: Model outputs, newest last; the last two are used.
            timestep_list: Timesteps matching ``model_output_list``.
            prev_timestep: Timestep being stepped to.
            sample: Current sample.

        Returns:
            The sample at ``prev_timestep``.

        Raises:
            NotImplementedError: For any configuration other than
                ``algorithm_type == "dpmsolver++"`` with
                ``solver_type == "midpoint"``.
        """
        algorithm_type = self.config.algorithm_type
        solver_type = self.config.solver_type
        # Fail fast with an interpolated message. Previously the message was a
        # plain string (braces printed literally) and combinations outside the
        # explicit branches fell through to an unbound `x_t`.
        if algorithm_type != "dpmsolver++" or solver_type != "midpoint":
            raise NotImplementedError(
                f"{algorithm_type} with {solver_type} is currently not supported."
            )

        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        diff = (D0 + 0.5 * D1)

        # Heavy-ball blend; the first call seeds the velocity with `diff`.
        if self.vel is None:
            self.vel = diff
        else:
            self.vel = (1-self.beta)*self.vel + self.beta * diff

        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
        )
        return x_t
186
+
187
class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.
    Currently supports only self.predict_x0 = True

    When beta = 1.0, this method is equivalent to UniPC without momentum.

    Separate velocities are kept for the predictor (``vel_p``) and the
    corrector (``vel_c``); ``initialize_momentum`` must be called before
    sampling so that they exist.
    """
    def initialize_momentum(self, beta):
        """Reset both velocities and set the momentum coefficient ``beta``."""
        self.vel_p = None
        self.vel_c = None
        self.beta = beta

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """Predictor step.

        Uses the stored ``self.timestep_list`` / ``self.model_outputs`` history
        and steps ``sample`` to ``prev_timestep``. If ``self.solver_p`` is set,
        the step is delegated to it entirely.

        Raises:
            NotImplementedError: If ``solver_type`` is neither "bh1" nor "bh2",
                or if ``self.predict_x0`` is falsy.
        """

        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = self.timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample

        # An external predictor, when configured, replaces this update.
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = sample.device

        # For each older history entry: its lambda offset relative to h (rk)
        # and the correspondingly scaled difference of model outputs.
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        # Build the linear system R @ rhos = b row by row (powers of rks on
        # the left, scaled h_phi_k terms on the right).
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if self.predict_x0:
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0

            # `val` is the update direction normalised by sigma_t * h_phi_1 so
            # that it can be blended into the velocity and rescaled below.
            val = ( h_phi_1 * m0 + B_h * pred_res ) /sigma_t /h_phi_1
            if self.vel_p is None:
                self.vel_p = val
            else:
                self.vel_p = (1-self.beta)*self.vel_p + self.beta * val
                # NOTE(review): this assignment overwrites the blended value
                # just computed, so the predictor effectively runs without
                # momentum (vel_p always equals val). Confirm whether this is
                # intentional or a debugging leftover.
                self.vel_p = val

            x_t = sigma_t * (x/ sigma_s0 - alpha_t * self.vel_p * h_phi_1)
        else:
            raise NotImplementedError

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """Corrector step.

        Recomputes the step from ``last_sample`` to ``this_timestep`` using the
        stored history plus ``this_model_output``, blending the update into the
        corrector velocity ``vel_c``.

        Raises:
            NotImplementedError: If ``solver_type`` is neither "bh1" nor "bh2",
                or if ``self.predict_x0`` is falsy.
        """

        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        # Placeholder only: x_t is recomputed from `x` below before being returned.
        x_t = this_sample
        model_t = this_model_output

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = this_sample.device

        # Same history terms as in the predictor.
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if self.predict_x0:
            # All but the last coefficient weight the history terms; the last
            # one weights the difference to the new model output.
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0

            val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t))/sigma_t/h_phi_1
            # Heavy-ball blend; the first call seeds the velocity with `val`.
            if self.vel_c is None:
                self.vel_c = val
            else:
                self.vel_c = (1-self.beta)*self.vel_c + self.beta * val

            x_t = sigma_t * (x/ sigma_s0 - alpha_t * self.vel_c * h_phi_1)
        else:
            raise NotImplementedError

        x_t = x_t.to(x.dtype)
        return x_t
pipeline.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UniPCMultistepScheduler
5
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
6
+ from typing import Union, Optional, List, Callable, Dict, Any, Tuple
7
+ from momentum_scheduler import (
8
+ GHVBScheduler,
9
+ PLMSWithHBScheduler,
10
+ PLMSWithNTScheduler,
11
+ MomentumDPMSolverMultistepScheduler,
12
+ MomentumUniPCMultistepScheduler,
13
+ )
14
+
15
# Registry mapping solver names to the momentum scheduler classes imported
# from momentum_scheduler. "PLMS_HB" / "PLMS_NT" are the heavy-ball and
# Nesterov variants of PLMS.
available_solvers = {
    "GHVB": GHVBScheduler,
    "PLMS_HB": PLMSWithHBScheduler,
    "PLMS_NT": PLMSWithNTScheduler,
    "DPM-Solver++": MomentumDPMSolverMultistepScheduler,
    "UniPC": MomentumUniPCMultistepScheduler,
}
22
+
23
def get_momentum_number(order, beta):
    """Fold ``order`` and ``beta`` into a single, possibly fractional, number.

    Returns ``order`` unchanged when ``beta`` equals 1.0, otherwise
    ``order - (1 - beta)``.
    """
    if beta == 1.0:
        return order
    return order - (1 - beta)
26
+
27
def setup_scheduler(pipe, scheduler, momentum_type="Polyak's heavy ball", order=4.0, beta=1.0, original_config=None):
    """Attach the requested momentum sampler to ``pipe`` and return it.

    Args:
        pipe: Pipeline whose ``scheduler`` attribute is replaced.
        scheduler: ``"DPM-Solver++"``, ``"UniPC"``, ``"PLMS"`` or ``"GHVB"``.
            Any other name leaves ``pipe`` untouched.
        momentum_type: For ``"PLMS"``, ``"Polyak's heavy ball"`` selects the
            heavy-ball wrapper and anything else the Nesterov wrapper.
            ``"Nesterov"`` is rejected for ``"DPM-Solver++"`` / ``"UniPC"``.
        order: Solver order (used by ``"PLMS"`` and ``"GHVB"``).
        beta: Momentum coefficient.
        original_config: Scheduler config to build the new scheduler from;
            required.

    Returns:
        The same ``pipe`` object.

    Raises:
        AssertionError: If ``original_config`` is None.
        NotImplementedError: For Nesterov momentum with DPM-Solver++ or UniPC.
    """
    assert original_config is not None

    if scheduler in ("DPM-Solver++", "UniPC"):
        if momentum_type == "Nesterov":
            raise NotImplementedError(f"{scheduler} w/ Nesterov is not implemented.")

        # These solvers carry the momentum inside the scheduler itself.
        solver_cls = available_solvers[scheduler]
        pipe.scheduler = solver_cls.from_config(original_config)
        pipe.scheduler.initialize_momentum(beta=beta)
        return pipe

    if scheduler in ("PLMS", "GHVB"):
        # These solvers wrap a DPMSolverMultistepScheduler; order and beta are
        # passed to the wrapper as one combined number.
        combined_order = get_momentum_number(order, beta)
        if scheduler == "GHVB":
            wrapper_name = "GHVB"
        elif momentum_type == "Polyak's heavy ball":
            wrapper_name = "PLMS_HB"
        else:
            wrapper_name = "PLMS_NT"

        pipe.scheduler = DPMSolverMultistepScheduler.from_config(original_config)
        pipe.init_scheduler(method=wrapper_name, order=combined_order)
        pipe.clear_scheduler()

    return pipe
51
+
52
class CustomPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline whose denoising loop can use momentum samplers.

    The ``method`` argument of ``__call__`` selects how each step is taken:
    ``"dpm"`` / ``"unipc"`` call ``self.scheduler.step`` on the guided noise
    prediction, while ``"hb"`` / ``"ghvb"`` / ``"nt"`` step the unconditional
    prediction and the guidance term through two separate wrapper schedulers
    created by ``init_scheduler``.
    """

    def clear_scheduler(self):
        """Reset the state of both wrapper schedulers."""
        self.scheduler_uncond.clear()
        self.scheduler_text.clear()

    def init_scheduler(self, method, order):
        """Create the two wrapper schedulers around ``self.scheduler``.

        Args:
            method: Key into ``available_solvers``.
            order: Order value passed to the wrapper constructor.
        """
        # equivalent to not applying numerical operator splitting since orders are the same
        self.scheduler_uncond = available_solvers[method](self.scheduler, order)
        self.scheduler_text = available_solvers[method](self.scheduler, order)

    def get_noise(self, latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance):
        """Run the UNet and split its output into base prediction and guidance term.

        Returns:
            ``(noise_pred_uncond, grads_a)``. With classifier-free guidance,
            ``grads_a`` is ``guidance_scale * (text - uncond)``. Without it,
            the single prediction is returned together with a zero guidance
            term, so that callers can treat both cases uniformly.
        """
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

        if not do_classifier_free_guidance:
            # Previously both return values were left unbound on this path,
            # so any call with guidance_scale <= 1 raised UnboundLocalError.
            return noise_pred, torch.zeros_like(noise_pred)

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        grads_a = guidance_scale * (noise_pred_text - noise_pred_uncond)
        return noise_pred_uncond, grads_a

    def denoising_step(
        self,
        latents,
        prompt_embeds,
        guidance_scale,
        t,
        do_classifier_free_guidance,
        method,
        extra_step_kwargs,
    ):
        """Advance ``latents`` by one step at timestep ``t`` using ``method``.

        Raises:
            NotImplementedError: If ``method`` is not one of
                ``"dpm"``, ``"unipc"``, ``"hb"``, ``"ghvb"``, ``"nt"``.
        """
        noise_pred_uncond, grads_a = self.get_noise(
            latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance
        )
        if method in ["dpm", "unipc"]:
            latents = self.scheduler.step(noise_pred_uncond + grads_a, t, latents, **extra_step_kwargs).prev_sample

        elif method in ["hb", "ghvb", "nt"]:
            # The unconditional prediction and the guidance term are stepped
            # through their own wrapper schedulers, one after the other.
            latents = self.scheduler_uncond.step(noise_pred_uncond, t, latents, output_mode="scale")
            latents = self.scheduler_text.step(grads_a, t, latents, output_mode='back')
        else:
            raise NotImplementedError

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        method="ghvb",
    ):
        """Generate images for ``prompt`` using the selected stepping ``method``.

        ``method`` is forwarded to ``denoising_step``. ``output_type`` selects
        the result format: ``"latent"`` returns the raw latents, ``"pil"``
        returns PIL images, anything else returns the decoded array. The
        safety checker is not run; ``nsfw_content_detected`` is ``None`` for
        latents and ``False`` otherwise. ``cross_attention_kwargs`` is accepted
        but not used by this implementation.

        Returns:
            ``StableDiffusionPipelineOutput`` when ``return_dict`` is true,
            otherwise the tuple ``(image, has_nsfw_concept)``.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = self.denoising_step(
                    latents,
                    prompt_embeds,
                    guidance_scale,
                    t,
                    do_classifier_free_guidance,
                    method,
                    extra_step_kwargs,
                )

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. The safety checker is intentionally not run here.
            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            has_nsfw_concept = False

            # 10. Convert to PIL
            if output_type == "pil":
                image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def generate(self, params):
        """Run the pipeline and return both the first decoded image and the latents.

        Args:
            params: Keyword arguments for ``__call__``. ``output_type`` is
                forced to ``"latent"`` on a copy; the caller's dict is left
                unmodified.

        Returns:
            ``(image, ori_latents)`` where ``image`` is the first PIL image
            decoded from a clone of the latents.
        """
        # Copy so the caller's dict is not silently changed to output_type="latent".
        call_params = {**params, "output_type": "latent"}
        ori_latents = self(**call_params)["images"]

        with torch.no_grad():
            latents = torch.clone(ori_latents)
            image = self.decode_latents(latents)
            image = self.numpy_to_pil(image)[0]

        return image, ori_latents
requirements.txt ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.21.0
2
+ aiofiles==23.1.0
3
+ aiohttp==3.8.4
4
+ aiosignal==1.3.1
5
+ altair==5.0.1
6
+ annotated-types==0.5.0
7
+ anyio==3.7.1
8
+ async-timeout==4.0.2
9
+ attrs==23.1.0
10
+ certifi==2023.5.7
11
+ charset-normalizer==3.2.0
12
+ click==8.1.5
13
+ cmake==3.26.4
14
+ contourpy==1.1.0
15
+ cycler==0.11.0
16
+ diffusers==0.15.0
17
+ exceptiongroup==1.1.2
18
+ fastapi==0.100.0
19
+ ffmpy==0.3.0
20
+ filelock==3.12.2
21
+ fonttools==4.41.0
22
+ frozenlist==1.4.0
23
+ fsspec==2023.6.0
24
+ gradio==3.36.1
25
+ gradio_client==0.2.9
26
+ h11==0.14.0
27
+ httpcore==0.17.3
28
+ httpx==0.24.1
29
+ huggingface-hub==0.16.4
30
+ idna==3.4
31
+ importlib-metadata==6.8.0
32
+ importlib-resources==6.0.0
33
+ Jinja2==3.1.2
34
+ jsonschema==4.18.3
35
+ jsonschema-specifications==2023.6.1
36
+ kiwisolver==1.4.4
37
+ linkify-it-py==2.0.2
38
+ lit==16.0.6
39
+ markdown-it-py==2.2.0
40
+ MarkupSafe==2.1.3
41
+ matplotlib==3.7.2
42
+ mdit-py-plugins==0.3.3
43
+ mdurl==0.1.2
44
+ mpmath==1.3.0
45
+ multidict==6.0.4
46
+ networkx==3.1
47
+ numpy==1.25.1
48
+ nvidia-cublas-cu11==11.10.3.66
49
+ nvidia-cuda-cupti-cu11==11.7.101
50
+ nvidia-cuda-nvrtc-cu11==11.7.99
51
+ nvidia-cuda-runtime-cu11==11.7.99
52
+ nvidia-cudnn-cu11==8.5.0.96
53
+ nvidia-cufft-cu11==10.9.0.58
54
+ nvidia-curand-cu11==10.2.10.91
55
+ nvidia-cusolver-cu11==11.4.0.1
56
+ nvidia-cusparse-cu11==11.7.4.91
57
+ nvidia-nccl-cu11==2.14.3
58
+ nvidia-nvtx-cu11==11.7.91
59
+ orjson==3.9.2
60
+ packaging==23.1
61
+ pandas==2.0.3
62
+ Pillow==10.0.0
63
+ psutil==5.9.5
64
+ pydantic==2.0.2
65
+ pydantic_core==2.1.2
66
+ pydub==0.25.1
67
+ Pygments==2.15.1
68
+ pyparsing==3.0.9
69
+ python-dateutil==2.8.2
70
+ python-multipart==0.0.6
71
+ pytz==2023.3
72
+ PyYAML==6.0
73
+ referencing==0.29.1
74
+ regex==2023.6.3
75
+ requests==2.31.0
76
+ rpds-py==0.8.10
77
+ semantic-version==2.10.0
78
+ six==1.16.0
79
+ sniffio==1.3.0
80
+ starlette==0.27.0
81
+ sympy==1.12
82
+ tokenizers==0.13.3
83
+ tomesd==0.1.3
84
+ toolz==0.12.0
85
+ torch==2.0.1
86
+ tqdm==4.65.0
87
+ transformers==4.28.1
88
+ triton==2.0.0
89
+ typing_extensions==4.7.1
90
+ tzdata==2023.3
91
+ uc-micro-py==1.0.2
92
+ urllib3==2.0.3
93
+ uvicorn==0.22.0
94
+ websockets==11.0.3
95
+ yarl==1.9.2
96
+ zipp==3.16.1