File size: 5,722 Bytes
4f8ad24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import logging
from enum import IntEnum
from typing import Iterator, Optional, List, Tuple

import numpy as np
from hbutils.string import plural_word
from hbutils.testing import disable_output
from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_clustering, ccip_batch_differences

from .base import BaseAction
from ..model import ImageItem


class CCIPStatus(IntEnum):
    """
    Phases of the state machine driven by :class:`CCIPAction`.
    """
    # No anchor source was given; incoming items are buffered until clustering can be attempted.
    INIT = 1
    # The first clustering attempt did not succeed; buffering continues with periodic retries.
    APPROACH = 2
    # A reference feature set is available; incoming items are compared against it.
    EVAL = 3
    # An anchor source was given but its items have not been loaded yet.
    INIT_WITH_SOURCE = 4


class CCIPAction(BaseAction):
    """
    Stateful filter that only lets through items whose CCIP feature agrees with a reference set.

    The reference set is obtained in one of two ways:

    * If ``init_source`` is given, its items are loaded (and yielded) on the first call to
      :meth:`iter`, and their features form the reference set.
    * Otherwise incoming items are buffered; once ``min_val_count`` of them are available they are
      clustered and, when one cluster is dominant and self-consistent enough, its members become
      the reference set. While that fails, clustering is retried every ``step`` further items.

    Once a reference set exists, every incoming item is compared against it and yielded only when it
    matches; features of accepted items are added to the reference set.

    :param init_source: Optional iterable of anchor :class:`ImageItem` objects. ``None`` means the
        reference set is bootstrapped by clustering.
    :param min_val_count: Number of buffered items required before the first clustering attempt.
    :param step: Retry interval (in buffered items) for clustering, and re-check interval
        (in accepted items) for buffered items that have not been released yet.
    :param ratio_threshold: Minimum share of the clustered (non ``-1``) samples that a cluster must
        hold to be chosen.
    :param min_clu_dump_ratio: Minimum fraction of the chosen cluster's members that must match the
        cluster itself for the cluster to be accepted.
    :param cmp_threshold: Minimum fraction of reference features a feature must be close to
        (difference ``<= threshold``) in order to count as a match.
    :param eps: Forwarded to ``ccip_clustering``.
    :param min_samples: Forwarded to ``ccip_clustering``.
    :param model: Model name forwarded to the ``imgutils.metrics`` CCIP functions.
    :param threshold: Maximum difference for two features to be considered close. ``None`` means
        ``ccip_default_threshold(model)`` is used.
    """

    def __init__(self, init_source=None, *, min_val_count: int = 15, step: int = 5,
                 ratio_threshold: float = 0.6, min_clu_dump_ratio: float = 0.3, cmp_threshold: float = 0.5,
                 eps: Optional[float] = None, min_samples: Optional[int] = None,
                 model='ccip-caformer-24-randaug-pruned', threshold: Optional[float] = None):
        self.init_source = init_source

        self.min_val_count = min_val_count
        self.step = step
        self.ratio_threshold = ratio_threshold
        self.min_clu_dump_ratio = min_clu_dump_ratio
        self.cmp_threshold = cmp_threshold
        self.eps, self.min_samples = eps, min_samples
        self.model = model
        # Test against None (not truthiness) so that an explicit threshold of 0.0 is honoured.
        self.threshold = ccip_default_threshold(self.model) if threshold is None else threshold

        # Buffered items, aligned index-wise with ``item_released`` and the head of ``feats``.
        self.items = []
        # item_released[i] is True once items[i] has been yielded by _dump_items().
        self.item_released = []
        # Reference features; may be longer than ``items`` (anchor / accepted features are appended).
        self.feats = []
        self.status = self._initial_status()

    def _initial_status(self) -> CCIPStatus:
        """
        Return the status a fresh (or reset) action starts in.

        Shared by ``__init__`` and :meth:`reset` so both agree on what "an anchor was supplied"
        means: ``init_source is not None`` (an empty anchor is still an anchor).
        """
        if self.init_source is None:
            return CCIPStatus.INIT
        return CCIPStatus.INIT_WITH_SOURCE

    def _extract_feature(self, item: ImageItem):
        """
        Return the CCIP feature of ``item``.

        A feature stored under ``item.meta['ccip_feature']`` is reused; otherwise it is computed
        from ``item.image`` with the configured model.
        """
        if 'ccip_feature' in item.meta:
            return item.meta['ccip_feature']
        return ccip_extract_feature(item.image, model=self.model)

    def _try_cluster(self) -> bool:
        """
        Cluster the buffered features and, on success, reduce the buffers to the chosen cluster.

        :return: ``True`` when a cluster was chosen and accepted (``items``, ``item_released`` and
            ``feats`` are then replaced by that cluster's members), ``False`` otherwise (state is
            left untouched).
        """
        with disable_output():
            clu_ids = ccip_clustering(self.feats, method='optics', model=self.model,
                                      eps=self.eps, min_samples=self.min_samples)

        # Cluster sizes; label -1 is skipped (presumably the "noise" label - confirm with imgutils).
        clu_counts = {}
        for id_ in clu_ids:
            if id_ != -1:
                clu_counts[id_] = clu_counts.get(id_, 0) + 1

        # The first cluster holding the required share of all clustered samples wins.
        min_count = sum(clu_counts.values()) * self.ratio_threshold
        chosen_id = next((id_ for id_, count in clu_counts.items() if count >= min_count), None)
        if chosen_id is None:
            return False

        keep = [i for i, id_ in enumerate(clu_ids) if id_ == chosen_id]
        feats = [self.feats[i] for i in keep]

        # One pairwise matrix for the whole cluster: row i holds the differences of member i to
        # every member (itself included), which is what a per-member comparison against the
        # cluster would compute - without re-running the batch comparison once per member.
        diffs = ccip_batch_differences(feats, model=self.model)
        member_matches = (diffs <= self.threshold).astype(float).mean(axis=1) >= self.cmp_threshold
        clu_dump_ratio = member_matches.astype(float).mean()

        if clu_dump_ratio >= self.min_clu_dump_ratio:
            self.items = [self.items[i] for i in keep]
            self.item_released = [False] * len(self.items)
            self.feats = feats
            return True
        return False

    def _compare_to_exists(self, feat, base_set=None) -> bool:
        """
        Tell whether ``feat`` matches a set of reference features.

        :param feat: Feature to test.
        :param base_set: Reference features; ``None`` means ``self.feats``.
        :return: ``True`` when the fraction of reference features whose difference to ``feat`` is
            ``<= self.threshold`` reaches ``self.cmp_threshold``. ``False`` for an empty reference set.
        """
        refs = self.feats if base_set is None else base_set
        if len(refs) == 0:
            # Nothing to compare against; the mean over an empty array would be NaN (never a match).
            return False

        # Row 0 of the pairwise matrix, without the self-comparison in column 0.
        diffs = ccip_batch_differences([feat, *refs], model=self.model)[0, 1:]
        matches = diffs <= self.threshold
        return bool(matches.astype(float).mean() >= self.cmp_threshold)

    def _dump_items(self) -> Iterator[ImageItem]:
        """
        Yield every buffered item that has not been released yet and matches the reference set.
        """
        for i, (item, feat) in enumerate(zip(self.items, self.feats)):
            if self.item_released[i]:
                continue
            if self._compare_to_exists(feat):
                self.item_released[i] = True
                yield item

    def _eval_iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Evaluate one incoming item against the reference set.

        A matching item is yielded and its feature is added to the reference set. Every ``step``
        accepted features, the still unreleased buffered items are re-checked.
        """
        feat = self._extract_feature(item)
        if not self._compare_to_exists(feat):
            return

        self.feats.append(feat)
        yield item

        # len(feats) - len(items) is the number of features added on top of the buffered items.
        if (len(self.feats) - len(self.items)) % self.step == 0:
            yield from self._dump_items()

    def iter(self, item: ImageItem) -> Iterator[ImageItem]:
        """
        Feed one item into the action and yield whatever items are released as a consequence.

        :param item: Incoming item.
        :return: Iterator of released items (possibly empty; may include anchor or buffered items).
        :raises ValueError: If ``self.status`` is not a known :class:`CCIPStatus`.
        """
        if self.status == CCIPStatus.INIT_WITH_SOURCE:
            # First call with an anchor: load (and pass through) the anchor items, then evaluate.
            cnt = 0
            logging.info('Existing anchor detected.')
            for item_ in self.init_source:
                self.feats.append(self._extract_feature(item_))
                yield item_
                cnt += 1
            logging.info(f'{plural_word(cnt, "items")} loaded from anchor.')

            self.status = CCIPStatus.EVAL
            yield from self._eval_iter(item)

        elif self.status == CCIPStatus.INIT:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            if len(self.items) >= self.min_val_count:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()
                else:
                    self.status = CCIPStatus.APPROACH

        elif self.status == CCIPStatus.APPROACH:
            self.items.append(item)
            self.feats.append(self._extract_feature(item))

            # Retry clustering only every ``step`` items beyond the initial ``min_val_count``.
            if (len(self.items) - self.min_val_count) % self.step == 0:
                if self._try_cluster():
                    self.status = CCIPStatus.EVAL
                    yield from self._dump_items()

        elif self.status == CCIPStatus.EVAL:
            yield from self._eval_iter(item)

        else:
            raise ValueError(f'Unknown status for {self.__class__.__name__} - {self.status!r}.')

    def reset(self):
        """
        Drop all buffered items and features and return to the initial status.
        """
        self.items.clear()
        self.item_released.clear()
        self.feats.clear()
        self.status = self._initial_status()