File size: 1,579 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Union

from mmcv.transforms import BaseTransform

# Type alias for a transform pipeline: a list whose items are either dicts
# or ``BaseTransform`` instances.
PIPELINE_TYPE = List[Union[dict, BaseTransform]]


def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Locate the first transform in a pipeline that matches a class name.

    Each pipeline entry is reduced to a single name and compared with
    ``target``:

    - dict entry whose ``'type'`` value is a class: the class ``__name__``;
    - dict entry with any other ``'type'`` value: that value itself;
    - any other entry: the ``__name__`` of the entry's class.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.

    Returns:
        int: Index of the first matching entry, or -1 when nothing matches.
    """
    for position, entry in enumerate(pipeline):
        if isinstance(entry, dict):
            kind = entry['type']
            # A class is matched by its name; anything else (normally a
            # string) is compared to the target directly.
            candidate = kind.__name__ if isinstance(kind, type) else kind
        else:
            # Already-built transform object: match on its class name.
            candidate = entry.__class__.__name__

        if candidate == target:
            return position

    return -1


def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
    """Drop every transform matching ``target`` from a pipeline.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.
        inplace (bool): If True, mutate and return the given list itself.
            Otherwise the removal is done on a deep copy and the input is
            left untouched. Defaults to False.

    Returns:
        The pipeline with all matching transforms removed. This is the
        input list object when ``inplace`` is True, else a deep copy.
    """
    # Pick the list that will actually be modified.
    result = pipeline if inplace else copy.deepcopy(pipeline)

    # Search again after each removal until no match is left, so repeated
    # occurrences of the target are all removed.
    while True:
        position = get_transform_idx(result, target)
        if position < 0:
            break
        result.pop(position)

    return result