import torch


def b(a: torch.Tensor) -> torch.Tensor:
    """Add 3 to ``a`` in place, then add 3 more to elements equal to 4.

    Both steps mutate the tensor that was passed in, so the caller's
    tensor is changed as well. The tensor is printed after each step.

    Args:
        a: Tensor to modify in place.

    Returns:
        The same tensor object that was passed in, after both updates.
    """
    # Augmented assignment on a tensor is in-place: it updates the caller's
    # object rather than rebinding the local name to a new tensor.
    a += 3
    print(a)
    # Boolean-mask indexing combined with ``+=`` writes the result back into
    # the selected positions of the same tensor.
    a[a == 4] += 3
    print(a)
    return a


a = torch.ones(10, 2)
# Move to the GPU only when one is present so the script also runs on
# CPU-only machines instead of raising.
if torch.cuda.is_available():
    a = a.cuda()
print(a)
b(a)
# ``a`` now shows the modifications made inside ``b``.
print(a)