RoyYang0714 commited on
Commit
a41a682
·
1 Parent(s): 048731a

fix: Disable cuda for ms attention.

Browse files
Files changed (1) hide show
  1. vis4d/op/layer/ms_deform_attn.py +16 -16
vis4d/op/layer/ms_deform_attn.py CHANGED
@@ -542,22 +542,22 @@ class MultiScaleDeformableAttention(nn.Module):
542
  )
543
 
544
  if torch.cuda.is_available() and value.is_cuda:
545
- if VIS4D_CUDA_OPS_AVAILABLE:
546
- output = MSDeformAttentionFunction.apply(
547
- value,
548
- input_spatial_shapes,
549
- input_level_start_index,
550
- sampling_locations,
551
- attention_weights,
552
- self.im2col_step,
553
- )
554
- else:
555
- output = ms_deformable_attention_cpu(
556
- value.cpu(),
557
- input_spatial_shapes.cpu(),
558
- sampling_locations.cpu(),
559
- attention_weights.cpu(),
560
- ).cuda()
561
  else:
562
  output = ms_deformable_attention_cpu(
563
  value,
 
542
  )
543
 
544
  if torch.cuda.is_available() and value.is_cuda:
545
+ # if VIS4D_CUDA_OPS_AVAILABLE:
546
+ # output = MSDeformAttentionFunction.apply(
547
+ # value,
548
+ # input_spatial_shapes,
549
+ # input_level_start_index,
550
+ # sampling_locations,
551
+ # attention_weights,
552
+ # self.im2col_step,
553
+ # )
554
+ # else:
555
+ output = ms_deformable_attention_cpu(
556
+ value.cpu(),
557
+ input_spatial_shapes.cpu(),
558
+ sampling_locations.cpu(),
559
+ attention_weights.cpu(),
560
+ ).cuda()
561
  else:
562
  output = ms_deformable_attention_cpu(
563
  value,