Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List | |
| import torch | |
| from mmengine.model import BaseTTAModel | |
| from mmengine.structures import PixelData | |
| from mmseg.registry import MODELS | |
| from mmseg.utils import SampleList | |
# NOTE(review): ``MODELS`` is imported in this file but this class carries no
# ``@MODELS.register_module()`` decorator, so it is not registered here.
# Confirm it is registered elsewhere before referencing it by name in a
# config.
class SegTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper for semantic segmentors.

    Fuses the predictions produced for the augmented views of each image into
    a single ``pred_sem_seg`` by averaging per-pixel probabilities.
    """

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        For each image, the ``seg_logits`` of all of its augmented views are
        converted to probabilities and averaged. Softmax over dim 0 is used
        when the wrapped model has more than one output channel; sigmoid is
        used otherwise. The averaged probabilities are then turned into a
        label map: ``argmax`` over dim 0 for multi-channel models, or a
        comparison against ``decode_head.threshold`` for single-channel
        models.

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data. Each inner list holds the data samples
                of the augmented views of one image and must be non-empty.

        Returns:
            SampleList: Merged prediction, one data sample per image. Each
            returned sample is the last augmented view's data sample, updated
            in place with ``pred_sem_seg``, the first view's ``img_path``
            metainfo, and the first view's ``gt_sem_seg`` when it has one.
        """
        # Decided once by the wrapped model; selects both the probability
        # transform and the decoding rule below.
        multi_class = self.module.out_channels > 1
        predictions = []
        for data_samples in data_samples_list:
            first_sample = data_samples[0]
            # zeros_like allocates with the logits' dtype and device directly,
            # instead of allocating on the CPU and copying.
            logits = torch.zeros_like(first_sample.seg_logits.data)
            for view_sample in data_samples:
                seg_logit = view_sample.seg_logits.data
                if multi_class:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)

            # Dim 0 is the channel axis (softmax/argmax run over it). Keep it
            # so both branches yield a (1, H, W) map, the shape PixelData
            # stores, without an (H, W) -> (1, H, W) conversion.
            # NOTE(review): assumes seg_logits.data is (C, H, W), as PixelData
            # stores it - confirm.
            if multi_class:
                seg_pred = logits.argmax(dim=0, keepdim=True)
            else:
                threshold = self.module.decode_head.threshold
                seg_pred = (logits > threshold).to(logits)

            # The merged result is attached to the last augmented view's
            # sample; ground truth and image path are taken from the first.
            merged_sample = data_samples[-1]
            merged_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
            if hasattr(first_sample, 'gt_sem_seg'):
                merged_sample.set_data(
                    {'gt_sem_seg': first_sample.gt_sem_seg})
            merged_sample.set_metainfo({'img_path': first_sample.img_path})
            predictions.append(merged_sample)
        return predictions