para-lost commited on
Commit
c387723
·
1 Parent(s): 37b1444

fix device

Browse files
Files changed (1) hide show
  1. pipeline.py +1 -1
pipeline.py CHANGED
@@ -6819,7 +6819,7 @@ class InterleaveInferencer:
6819
  past_key_values = gen_context['past_key_values']
6820
  kv_lens = gen_context['kv_lens']
6821
  ropes = gen_context['ropes']
6822
-
6823
  generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
6824
  generation_input = self._to_device(generation_input, device)
6825
  unpacked_latent = self.model.generate_text(
 
6819
  past_key_values = gen_context['past_key_values']
6820
  kv_lens = gen_context['kv_lens']
6821
  ropes = gen_context['ropes']
6822
+ device = next(self.model.parameters()).device
6823
  generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
6824
  generation_input = self._to_device(generation_input, device)
6825
  unpacked_latent = self.model.generate_text(