fix device
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -6819,7 +6819,7 @@ class InterleaveInferencer:
|
|
6819 |
past_key_values = gen_context['past_key_values']
|
6820 |
kv_lens = gen_context['kv_lens']
|
6821 |
ropes = gen_context['ropes']
|
6822 |
-
|
6823 |
generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
|
6824 |
generation_input = self._to_device(generation_input, device)
|
6825 |
unpacked_latent = self.model.generate_text(
|
|
|
6819 |
past_key_values = gen_context['past_key_values']
|
6820 |
kv_lens = gen_context['kv_lens']
|
6821 |
ropes = gen_context['ropes']
|
6822 |
+
device = next(self.model.parameters()).device
|
6823 |
generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
|
6824 |
generation_input = self._to_device(generation_input, device)
|
6825 |
unpacked_latent = self.model.generate_text(
|