File size: 2,211 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Any, Optional, Sequence, Union

import numpy as np
import torch
from mmengine.utils import is_seq_of
from torch import Tensor


def to_numpy(x: Union[Tensor, Sequence[Tensor]],
             return_device: bool = False,
             unzip: bool = False) -> Union[np.ndarray, tuple]:
    """Convert torch tensor to numpy.ndarray.

    Args:
        x (Tensor | Sequence[Tensor]): A single tensor or a sequence of
            tensors
        return_device (bool): Whether return the tensor device. Defaults to
            ``False``
        unzip (bool): Whether unzip the input sequence. Defaults to ``False``

    Returns:
        np.ndarray | tuple: If ``return_device`` is ``True``, return a tuple
        of converted numpy array(s) and the device indicator; otherwise only
        return the numpy array(s)
    """

    if isinstance(x, Tensor):
        arrays = x.detach().cpu().numpy()
        device = x.device
    elif is_seq_of(x, Tensor):
        if unzip:
            # convert (A, B) -> [(A[0], B[0]), (A[1], B[1]), ...]
            arrays = [
                tuple(to_numpy(_x[None, :]) for _x in _each)
                for _each in zip(*x)
            ]
        else:
            arrays = [to_numpy(_x) for _x in x]

        device = x[0].device

    else:
        raise ValueError(f'Invalid input type {type(x)}')

    if return_device:
        return arrays, device
    else:
        return arrays


def to_tensor(x: Union[np.ndarray, Sequence[np.ndarray]],
              device: Optional[Any] = None) -> Union[Tensor, Sequence[Tensor]]:
    """Convert numpy.ndarray to torch tensor.

    Args:
        x (np.ndarray | Sequence[np.ndarray]): A single np.ndarray or a
            sequence of tensors
        tensor (Any, optional): The device indicator. Defaults to ``None``

    Returns:
        tuple:
        - Tensor | Sequence[Tensor]: The converted Tensor or Tensor sequence
    """
    if isinstance(x, np.ndarray):
        return torch.tensor(x, device=device)
    elif is_seq_of(x, np.ndarray):
        return [to_tensor(_x, device=device) for _x in x]
    else:
        raise ValueError(f'Invalid input type {type(x)}')