iitolstykh committed
Commit 11b25a1 · verified · Parent: 816d4c8

Update inferencer.py

Files changed (1): inferencer.py (+42 -41)
inferencer.py CHANGED
@@ -51,8 +51,9 @@ class InterleaveInferencer:
             new_token_ids=self.new_token_ids,
         )
 
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+
         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
         gen_context['past_key_values'] = past_key_values
@@ -77,8 +78,8 @@ class InterleaveInferencer:
             transforms=self.vae_transform,
             new_token_ids=self.new_token_ids,
         )
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
 
         if vit:
             ## update vit
@@ -89,8 +90,8 @@ class InterleaveInferencer:
             transforms=self.vit_transform,
             new_token_ids=self.new_token_ids,
         )
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
 
         gen_context['kv_lens'] = kv_lens
         gen_context['ropes'] = ropes
@@ -146,28 +147,28 @@ class InterleaveInferencer:
             image_sizes=[image_shape],
         )
 
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            unpacked_latent = self.model.generate_image(
-                past_key_values=past_key_values,
-                cfg_text_past_key_values=cfg_text_past_key_values,
-                cfg_img_past_key_values=cfg_img_past_key_values,
-                num_timesteps=num_timesteps,
-                cfg_text_scale=cfg_text_scale,
-                cfg_img_scale=cfg_img_scale,
-                cfg_interval=cfg_interval,
-                cfg_renorm_min=cfg_renorm_min,
-                cfg_renorm_type=cfg_renorm_type,
-                timestep_shift=timestep_shift,
-                **generation_input,
-                cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
-                cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
-                cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
-                cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
-                cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
-                cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
-                cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
-                cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
-            )
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        unpacked_latent = self.model.generate_image(
+            past_key_values=past_key_values,
+            cfg_text_past_key_values=cfg_text_past_key_values,
+            cfg_img_past_key_values=cfg_img_past_key_values,
+            num_timesteps=num_timesteps,
+            cfg_text_scale=cfg_text_scale,
+            cfg_img_scale=cfg_img_scale,
+            cfg_interval=cfg_interval,
+            cfg_renorm_min=cfg_renorm_min,
+            cfg_renorm_type=cfg_renorm_type,
+            timestep_shift=timestep_shift,
+            **generation_input,
+            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
+            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
+            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
+            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
+            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
+            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
+            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
+            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
+        )
 
         image = self.decode_image(unpacked_latent[0], image_shape)
         return image
@@ -193,19 +194,19 @@ class InterleaveInferencer:
         kv_lens = gen_context['kv_lens']
         ropes = gen_context['ropes']
 
-        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
-            generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
-            for unpacked_latent in self.model.generate_text(
-                past_key_values=past_key_values,
-                max_length=max_length,
-                do_sample=do_sample,
-                temperature=temperature,
-                end_token_id=self.new_token_ids['eos_token_id'],
-                **generation_input,
-            ):
-                output = self.tokenizer.decode(unpacked_latent)
-                if output != "<|im_end|>":
-                    yield output
+        # with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
+        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
+        for unpacked_latent in self.model.generate_text(
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            temperature=temperature,
+            end_token_id=self.new_token_ids['eos_token_id'],
+            **generation_input,
+        ):
+            output = self.tokenizer.decode(unpacked_latent)
+            if output != "<|im_end|>":
+                yield output
 
     @torch.no_grad()
     def interleave_inference(
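Every hunk makes the same change: the torch.amp.autocast wrapper is commented out and its body de-indented, so each cache-update and generation call now runs in the dtype of its inputs and weights rather than being autocast to bfloat16. A minimal sketch of what the removed context manager does (illustrative tensors only, assuming a CUDA device; not code from this repo):

import torch

x = torch.randn(8, 8, device="cuda")  # float32 by default
w = torch.randn(8, 8, device="cuda")

# Inside the context, autocast-eligible ops such as matmul run in
# bfloat16 even though their inputs are float32.
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
    y_amp = x @ w
print(y_amp.dtype)     # torch.bfloat16

# Outside the context (the situation after this commit), the same op
# keeps the tensors' native dtype.
y_native = x @ w
print(y_native.dtype)  # torch.float32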
 
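A side note on the generate_image call, which this commit leaves unchanged: passing explicit keyword arguments after **generation_input is valid Python (PEP 448 allows keywords to follow a **-unpacking), but any key present both in the dict and as an explicit keyword raises TypeError at call time. A standalone illustration with hypothetical names:

def f(**kwargs):
    return kwargs

base = {"a": 1}
print(f(**base, b=2))   # {'a': 1, 'b': 2}: keywords after ** are fine

try:
    f(**base, a=3)      # 'a' supplied twice
except TypeError as err:
    print(err)          # f() got multiple values for keyword argument 'a'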
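Whether dropping autocast changes the numerics depends on how the checkpoint was loaded. If self.model's parameters are already bfloat16, the wrapper was largely redundant for matmul-style ops, though autocast also upcasts a handful of ops to float32, so outputs are not guaranteed to be bit-identical. A quick diagnostic, written as a hypothetical helper that is not part of this file:

import torch
from collections import Counter

def param_dtype_histogram(model: torch.nn.Module) -> Counter:
    # Tally parameter counts per dtype; if everything already sits in
    # torch.bfloat16, a bfloat16 autocast context adds little, and
    # removing it (as this commit does) mostly preserves behavior.
    hist = Counter()
    for p in model.parameters():
        hist[p.dtype] += p.numel()
    return hist

# Usage sketch: print(param_dtype_histogram(inferencer.model))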