Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
def drop_path(
    x: torch.Tensor,
    drop_prob: float = 0.0,
    training: bool = False,
    scale_by_keep: bool = True,
) -> torch.Tensor:
    """Randomly zero out whole samples of ``x`` (per-sample path dropping).

    Each sample along dim 0 is kept with probability ``1 - drop_prob`` and
    otherwise replaced by zeros. This is commonly used as "stochastic depth"
    on residual branches.

    Args:
        x: Input tensor. Dim 0 is treated as the sample dimension. The random
            mask is broadcast over all remaining dims, so any rank >= 1 works.
        drop_prob: Probability of dropping a sample. ``None`` or ``0.0``
            disables dropping.
        training: Dropping only happens when True. Otherwise ``x`` is returned
            unchanged.
        scale_by_keep: When True (the default), kept samples are divided by
            ``1 - drop_prob`` so the output's expected value equals ``x``.

    Returns:
        ``x`` itself when dropping is disabled. Otherwise a new tensor with the
        same shape as ``x``.
    """
    # ``not drop_prob`` covers both 0.0 and None (DropPath's default), which
    # would otherwise fail below on ``1 - None``.
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, shaped (N, 1, ..., 1). It broadcasts over
    # every non-sample dim, so this works for tensors of any rank, not just 4D.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    # When keep_prob == 0 the mask is all zeros; skip the division to avoid
    # dividing by zero.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` that follows the module's train/eval mode.

    Args:
        drop_prob: Probability of dropping each sample while training.
            ``None`` (the default) disables dropping, the same as ``0.0``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored exactly as given; None is translated at call time in forward.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Map None to 0.0 so drop_path always receives a number, even if the
        # attribute is reassigned to None after construction.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        # self.training is toggled by nn.Module.train()/eval().
        return drop_path(x, drop_prob, self.training)

    def extra_repr(self):
        # Shown inside the module's printed repr, e.g. DropPath(drop_prob=0.1).
        return f"drop_prob={self.drop_prob}"