Spaces:
Runtime error
Runtime error
File size: 4,930 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.utils import is_str
# Prefer the native implementation; older PyTorch releases do not ship it.
tensor_split = getattr(torch, 'tensor_split', None)
if tensor_split is None:

    def tensor_split(input: torch.Tensor, indices: list):
        """Split ``input`` along its first dimension at ``indices``.

        Minimal stand-in for :func:`torch.tensor_split` that only supports a
        list of split indices along dim 0. Returns a list of slices.
        """
        bounds = [0] + indices + [input.size(0)]
        return [input[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


# Input types accepted by ``format_label``.
LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int]
# Input types accepted by ``format_score``.
SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
def format_label(value: LABEL_TYPE) -> torch.Tensor:
    """Convert various python types to label-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and numpy scalar numbers (e.g.
    ``np.int64``).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D label tensor with dtype
        ``torch.long``.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the formatted tensor is not 1-D.
    """
    # Collapse 0-d tensors/arrays and numpy scalars into a plain Python int,
    # so they take the single-label branch below.
    if (isinstance(value, (torch.Tensor, np.ndarray, np.number))
            and value.ndim == 0):
        value = int(value.item())

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).to(torch.long)
    elif isinstance(value, Sequence) and not is_str(value):
        # Strings are Sequences too, but are never valid labels.
        value = torch.tensor(value).to(torch.long)
    elif isinstance(value, int):
        value = torch.tensor([value], dtype=torch.long)
    elif isinstance(value, torch.Tensor):
        # Match the other branches: labels are always integer class indices
        # (``F.one_hot`` rejects non-integer tensors). No-op for long input.
        value = value.to(torch.long)
    else:
        raise TypeError(f'Type {type(value)} is not an available label type.')

    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'
    return value
def format_score(value: SCORE_TYPE) -> torch.Tensor:
    """Convert various python types to score-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence): Score values.

    Returns:
        :obj:`torch.Tensor`: The formatted 1-D score tensor. numpy arrays and
        sequences are converted to ``float32``; tensors are returned as-is.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the formatted tensor is not 1-D.
    """
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).float()
    elif isinstance(value, Sequence) and not is_str(value):
        # Strings are Sequences too, but are never valid scores.
        value = torch.tensor(value).float()
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available score type.')

    assert value.ndim == 1, \
        f'The dims of value should be 1, but got {value.ndim}.'
    return value
def cat_batch_labels(elements: List[torch.Tensor]):
    """Concatenate the label tensors of a batch into a single tensor.

    Args:
        elements (List[tensor]): Per-sample label tensors.

    Returns:
        Tuple[torch.Tensor, List[int]]: The concatenated label tensor and the
        indices at which it splits back into the per-sample tensors.
    """
    labels = list(elements)
    # Running end offset of every sample inside the concatenated tensor.
    ends = []
    offset = 0
    for label in labels:
        offset += label.size(0)
        ends.append(offset)
    # The final end offset is the total length, which is not a split point.
    return torch.cat(labels), ends[:-1]
def batch_label_to_onehot(batch_label: torch.Tensor,
                          split_indices: List[int],
                          num_classes: int) -> torch.Tensor:
    """Convert a concated label tensor to onehot format.

    Each sample's row is a 0/1 vector: a class repeated inside one sample
    is still marked with 1, not with its count.

    Args:
        batch_label (torch.Tensor): A concated label tensor from multiple
            samples.
        split_indices (List[int]): The split indices of every sample.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor, one row per sample.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import batch_label_to_onehot
        >>> # Assume a concated label from 3 samples.
        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
        >>> split_indices = [2, 5]
        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
        tensor([[1, 1, 0, 0, 0],
                [1, 0, 1, 0, 1],
                [0, 1, 0, 1, 0]])
    """
    # One one-hot row per label occurrence (assumes ``batch_label`` is 1-D).
    sparse_onehot_list = F.one_hot(batch_label, num_classes)
    onehot_list = [
        sparse_onehot.sum(0)
        for sparse_onehot in tensor_split(sparse_onehot_list, split_indices)
    ]
    # Summing would count duplicated labels; clamp to keep a 0/1 encoding.
    return torch.stack(onehot_list).clamp(max=1)
def label_to_onehot(label: LABEL_TYPE, num_classes: int) -> torch.Tensor:
    """Convert a label to onehot format tensor.

    The result is a 0/1 vector: a class that appears more than once in
    ``label`` is still marked with 1, not with its count.

    Args:
        label (LABEL_TYPE): Label value.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import label_to_onehot
        >>> # Single-label
        >>> label_to_onehot(1, num_classes=5)
        tensor([0, 1, 0, 0, 0])
        >>> # Multi-label
        >>> label_to_onehot([0, 2, 3], num_classes=5)
        tensor([1, 0, 1, 1, 0])
        >>> # Duplicated labels are marked once
        >>> label_to_onehot([1, 1], num_classes=3)
        tensor([0, 1, 0])
    """
    label = format_label(label)
    # One one-hot row per label entry.
    sparse_onehot = F.one_hot(label, num_classes)
    # Summing would count duplicated labels; clamp to keep a 0/1 encoding.
    return sparse_onehot.sum(0).clamp(max=1)
|