Spaces:
Runtime error
Runtime error
File size: 1,579 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Union
from mmcv.transforms import BaseTransform
# Type alias for a transform pipeline: a list whose items are either
# transform config dicts or already-instantiated ``BaseTransform`` objects.
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Locate the first transform of a given type in a pipeline.

    Each entry is matched by name:

    - a config dict is matched on its ``'type'`` value. If that value is a
      class, the class's ``__name__`` is compared; otherwise the value
      itself is compared with ``target``.
    - any other object is matched on the name of its class.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.

    Returns:
        int: Index of the first matching transform, or -1 if none matches.
    """

    def _resolve_name(step):
        # Non-dict entries are transform instances: use their class name.
        if not isinstance(step, dict):
            return step.__class__.__name__
        kind = step['type']
        # A config may reference the transform class directly instead of
        # by its registered string name.
        return kind.__name__ if isinstance(kind, type) else kind

    for position, step in enumerate(pipeline):
        if _resolve_name(step) == target:
            return position
    return -1
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
    """Remove every transform of the target type from the pipeline.

    A transform matches when ``get_transform_idx`` would report it for
    ``target``. Matching dict configs are compared on their ``'type'``
    entry, and transform instances are compared on their class name.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.
        inplace (bool): Whether to modify the pipeline inplace. If False,
            a deep copy of ``pipeline`` is filtered and the input is left
            untouched.

    Returns:
        List[dict] | List[BaseTransform]: The pipeline with all matching
        transforms removed. When ``inplace`` is True this is the same list
        object that was passed in; otherwise it is a new deep-copied list.
    """
    if not inplace:
        pipeline = copy.deepcopy(pipeline)
    # Filter in a single pass instead of popping one match at a time and
    # re-scanning from the start each time (quadratic in pipeline length).
    # Each element is checked on its own so the matching rules stay exactly
    # those of ``get_transform_idx``. The new list is fully built before the
    # slice assignment, so an error while matching leaves the list unmodified.
    # Slice assignment keeps the same list object for the in-place case.
    pipeline[:] = [
        step for step in pipeline
        if get_transform_idx([step], target) < 0
    ]
    return pipeline
|