Kaixuanliu committed on
Commit
feaa3d5
·
1 Parent(s): 2563d19

Add Intel XPU platform support

Browse files

Signed-off-by: Liu, Kaixuan <[email protected]>

Files changed (3) hide show
  1. modeling_cogvlm.py +9 -3
  2. util.py +7 -1
  3. visual.py +2 -0
modeling_cogvlm.py CHANGED
@@ -8,6 +8,7 @@ from torch import nn
8
  from torch.nn import CrossEntropyLoss
9
  from torchvision import transforms
10
  from einops import rearrange
 
11
  from transformers import PreTrainedModel, PreTrainedTokenizer
12
  from transformers.utils.logging import get_logger
13
  from transformers.activations import ACT2FN
@@ -723,9 +724,14 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
723
  standardize_cache_format: bool = False,
724
  ) -> Dict[str, Any]:
725
  # update past_key_values
726
- cache_name, cache = self._extract_past_from_model_output(
727
- outputs, standardize_cache_format=standardize_cache_format
728
- )
 
 
 
 
 
729
  model_kwargs[cache_name] = cache
730
 
731
  if getattr(outputs, "state", None) is not None:
 
8
  from torch.nn import CrossEntropyLoss
9
  from torchvision import transforms
10
  from einops import rearrange
11
+ import transformers
12
  from transformers import PreTrainedModel, PreTrainedTokenizer
13
  from transformers.utils.logging import get_logger
14
  from transformers.activations import ACT2FN
 
724
  standardize_cache_format: bool = False,
725
  ) -> Dict[str, Any]:
726
  # update past_key_values
727
+ if transformers.__version__ >= "4.44.0":
728
+ cache_name, cache = self._extract_past_from_model_output(
729
+ outputs
730
+ )
731
+ else:
732
+ cache_name, cache = self._extract_past_from_model_output(
733
+ outputs, standardize_cache_format=standardize_cache_format
734
+ )
735
  model_kwargs[cache_name] = cache
736
 
737
  if getattr(outputs, "state", None) is not None:
util.py CHANGED
@@ -7,6 +7,10 @@ import torch.nn.functional as F
7
  import triton
8
  import triton.language as tl
9
 
 
 
 
 
10
 
11
  @triton.jit
12
  def rotary_kernel(
@@ -197,7 +201,9 @@ def apply_rotary(
197
 
198
  # Need this, otherwise Triton tries to launch from cuda:0 and we get
199
  # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
200
- with torch.cuda.device(x.device.index):
 
 
201
  rotary_kernel[grid](
202
  output, # data ptrs
203
  x,
 
7
  import triton
8
  import triton.language as tl
9
 
10
+ device_contexts = {
11
+ 'cuda': torch.cuda.device,
12
+ 'xpu': torch.xpu.device
13
+ }
14
 
15
  @triton.jit
16
  def rotary_kernel(
 
201
 
202
  # Need this, otherwise Triton tries to launch from cuda:0 and we get
203
  # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
204
+ device_type = x.device.type
205
+ assert device_type in device_contexts
206
+ with device_contexts[device_type](x.device.index):
207
  rotary_kernel[grid](
208
  output, # data ptrs
209
  x,
visual.py CHANGED
@@ -75,6 +75,8 @@ class Attention(nn.Module):
75
  out = out.transpose(2, 1)
76
  # breakpoint()
77
  # output = self.dense(out.reshape(B, L, -1))
 
 
78
  output = self.dense(out.view(B, L, -1))
79
  output = self.output_dropout(output)
80
  return output
 
75
  out = out.transpose(2, 1)
76
  # breakpoint()
77
  # output = self.dense(out.reshape(B, L, -1))
78
+ if not out.is_contiguous():
79
+ out = out.contiguous()
80
  output = self.dense(out.view(B, L, -1))
81
  output = self.output_dropout(output)
82
  return output