Update inferencer.py
inferencer.py CHANGED (+42 -41)
@@ -51,8 +51,9 @@ class InterleaveInferencer:
             new_token_ids=self.new_token_ids,
         )

-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+
         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
         gen_context['past_key_values'] = past_key_values
@@ -77,8 +78,8 @@ class InterleaveInferencer:
             transforms=self.vae_transform,
             new_token_ids=self.new_token_ids,
         )
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)

         if vit:
             ## update vit
@@ -89,8 +90,8 @@ class InterleaveInferencer:
             transforms=self.vit_transform,
             new_token_ids=self.new_token_ids,
         )
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)

         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
@@ -146,28 +147,28 @@ class InterleaveInferencer:
             image_sizes=[image_shape],
         )

-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            unpacked_latent = self.model.generate_image(
-                past_key_values=past_key_values,
-                cfg_text_past_key_values=cfg_text_past_key_values,
-                cfg_img_past_key_values=cfg_img_past_key_values,
-                num_timesteps=num_timesteps,
-                cfg_text_scale=cfg_text_scale,
-                cfg_img_scale=cfg_img_scale,
-                cfg_interval=cfg_interval,
-                cfg_renorm_min=cfg_renorm_min,
-                cfg_renorm_type=cfg_renorm_type,
-                timestep_shift=timestep_shift,
-                **generation_input,
-                cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
-                cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
-                cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
-                cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
-                cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
-                cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
-                cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
-                cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
-            )
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        unpacked_latent = self.model.generate_image(
+            past_key_values=past_key_values,
+            cfg_text_past_key_values=cfg_text_past_key_values,
+            cfg_img_past_key_values=cfg_img_past_key_values,
+            num_timesteps=num_timesteps,
+            cfg_text_scale=cfg_text_scale,
+            cfg_img_scale=cfg_img_scale,
+            cfg_interval=cfg_interval,
+            cfg_renorm_min=cfg_renorm_min,
+            cfg_renorm_type=cfg_renorm_type,
+            timestep_shift=timestep_shift,
+            **generation_input,
+            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
+            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
+            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
+            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
+            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
+            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
+            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
+            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
+        )

         image = self.decode_image(unpacked_latent[0], image_shape)
         return image
@@ -193,19 +194,19 @@ class InterleaveInferencer:
         kv_lens = gen_context['kv_lens']
         ropes = gen_context['ropes']

-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
-            for unpacked_latent in self.model.generate_text(
-                past_key_values=past_key_values,
-                max_length=max_length,
-                do_sample=do_sample,
-                temperature=temperature,
-                end_token_id=self.new_token_ids['eos_token_id'],
-                **generation_input,
-            ):
-                output = self.tokenizer.decode(unpacked_latent)
-                if output != "<|im_end|>":
-                    yield output
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
+        for unpacked_latent in self.model.generate_text(
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            temperature=temperature,
+            end_token_id=self.new_token_ids['eos_token_id'],
+            **generation_input,
+        ):
+            output = self.tokenizer.decode(unpacked_latent)
+            if output != "<|im_end|>":
+                yield output

     @torch.no_grad()
     def interleave_inference(
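The commit's pattern is the same in every hunk: the `torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)` context managers are commented out and the wrapped calls are de-indented, so the cache-update and generation calls now run in the model's own dtype. The standalone sketch below is not part of inferencer.py; it only illustrates what the removed wrapper does (the tensor names `x` and `w` are placeholders for illustration).

# Minimal sketch: ops inside the autocast context run in bfloat16 on CUDA,
# while the same op outside runs in the tensors' native float32 dtype.
import torch

def matmul_dtype(use_autocast: bool) -> torch.dtype:
    x = torch.randn(8, 8, device="cuda")
    w = torch.randn(8, 8, device="cuda")
    if use_autocast:
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            return (x @ w).dtype  # autocast downcasts the matmul to bfloat16
    return (x @ w).dtype          # without autocast: float32

if __name__ == "__main__":
    if torch.cuda.is_available():
        print(matmul_dtype(True))   # torch.bfloat16
        print(matmul_dtype(False))  # torch.float32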