File size: 2,544 Bytes
4f8ad24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Dict, Iterator, Literal

import numpy as np
from imgutils.metrics import lpips_difference, lpips_extract_feature

from .base import BaseAction
from ..model import ImageItem


class FeatureBucket:
    """
    Bounded store of previously accepted features, used for duplicate lookup.

    Each stored feature is paired with a ratio value (the caller in this file
    passes ``height / width`` of the image). A lookup only compares against
    stored features whose ratio is close to the queried ratio.

    :param threshold: Upper bound (inclusive) of ``lpips_difference`` for two
        features to be treated as duplicates.
    :param capacity: Number of most recent entries retained when the store is trimmed.
    :param rtol: Relative tolerance forwarded to ``np.isclose`` for ratio matching.
    :param atol: Absolute tolerance forwarded to ``np.isclose`` for ratio matching.
    """

    def __init__(self, threshold: float = 0.45, capacity: int = 500, rtol=1.e-5, atol=1.e-8):
        self.capacity = capacity
        self.threshold = threshold
        self.rtol = rtol
        self.atol = atol
        # ``features[i]`` and ``ratios[i]`` always describe the same entry.
        self.features = []
        self.ratios = np.array([], dtype=float)

    def check_duplicate(self, feat, ratio: float):
        """
        Tell whether ``feat`` duplicates any stored feature with a similar ratio.

        :param feat: Feature to look up.
        :param ratio: Ratio associated with ``feat``.
        :return: ``True`` if some stored entry has a close ratio and an
            ``lpips_difference`` to ``feat`` not above ``threshold``, else ``False``.
        """
        ratio_matches = np.isclose(self.ratios, ratio, rtol=self.rtol, atol=self.atol)
        # Only candidates with a matching ratio are compared; ``any`` stops at
        # the first hit, so later candidates are not evaluated.
        return any(
            lpips_difference(self.features[index], feat) <= self.threshold
            for index in np.flatnonzero(ratio_matches).tolist()
        )

    def add(self, feat, ratio: float):
        """
        Record a feature together with its ratio, trimming old entries if needed.

        :param feat: Feature to store.
        :param ratio: Ratio associated with ``feat``.
        """
        self.features.append(feat)
        self.ratios = np.append(self.ratios, ratio)

        if len(self.features) < 2 * self.capacity:
            return

        # The store has reached twice the capacity: keep only the newest
        # ``capacity`` entries of both parallel sequences.
        newest = slice(-self.capacity, None)
        self.features = self.features[newest]
        self.ratios = self.ratios[newest]


# Deduplication scopes accepted by FilterSimilarAction: 'all' compares every
# item against one shared bucket, 'group' keeps a separate bucket per group id.
FilterSimilarModeTyping = Literal['all', 'group']


class FilterSimilarAction(BaseAction):
    """
    Action that drops items whose image duplicates one already passed through.

    For every item, a feature is extracted with ``lpips_extract_feature`` and
    looked up in a :class:`FeatureBucket`. Items reported as duplicates are
    not yielded; all others are recorded in the bucket and yielded.

    :param mode: ``'all'`` to use one bucket for every item, ``'group'`` to
        use one bucket per ``group_id`` found in the item's meta.
    :param threshold: Duplicate threshold passed to each bucket.
    :param capacity: Capacity passed to each bucket.
    :param rtol: Relative ratio tolerance passed to each bucket.
    :param atol: Absolute ratio tolerance passed to each bucket.
    """

    def __init__(self, mode: FilterSimilarModeTyping = 'all', threshold: float = 0.45,
                 capacity: int = 500, rtol=5.e-2, atol=2.e-2):
        self.mode = mode
        self.threshold = threshold
        self.rtol = rtol
        self.atol = atol
        self.capacity = capacity
        # Per-group buckets, created lazily in 'group' mode.
        self.buckets: Dict[str, FeatureBucket] = {}
        # Single shared bucket used in 'all' mode.
        self.global_bucket = FeatureBucket(
            threshold=threshold, capacity=self.capacity, rtol=rtol, atol=atol,
        )

    def _get_bin(self, group_id):
        """
        Pick the bucket an item should be checked against.

        :param group_id: Group identifier of the item; only used in ``'group'`` mode.
        :return: The shared bucket in ``'all'`` mode, or the bucket of
            ``group_id`` (created on first use) in ``'group'`` mode.
        :raises ValueError: If ``self.mode`` is neither ``'all'`` nor ``'group'``.
        """
        if self.mode == 'all':
            return self.global_bucket
        if self.mode != 'group':
            raise ValueError(f'Unknown mode for filter similar action - {self.mode!r}.')

        group_bucket = self.buckets.get(group_id)
        if group_bucket is None:
            group_bucket = FeatureBucket(
                threshold=self.threshold, capacity=self.capacity, rtol=self.rtol, atol=self.atol,
            )
            self.buckets[group_id] = group_bucket
        return group_bucket

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Yield ``item`` unless its image duplicates an earlier one.

        :param item: Item to examine; its ``image`` and ``meta`` are read.
        :return: Iterator producing ``item`` once, or nothing for a duplicate.
        """
        picture = item.image
        aspect = picture.height * 1.0 / picture.width
        feature = lpips_extract_feature(picture)
        target = self._get_bin(item.meta.get('group_id'))

        if target.check_duplicate(feature, aspect):
            return

        # First occurrence: remember it so later look-alikes are filtered.
        target.add(feature, aspect)
        yield item

    def reset(self):
        """Forget every recorded feature: clear group buckets and recreate the shared one."""
        self.buckets.clear()
        self.global_bucket = FeatureBucket(
            threshold=self.threshold, capacity=self.capacity, rtol=self.rtol, atol=self.atol,
        )