|
""" Origin code from https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py """ |
|
|
|
"""Returns points that minimizes the maximum distance of any point to a center. |
|
|
|
Implements the k-Center-Greedy method in |
|
Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for |
|
Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 |
|
|
|
Distance metric defaults to l2 distance. Features used to calculate distance |
|
are either the raw features or, if the model has a transform method, the output
|
of model.transform(X). |
|
|
|
Can be extended to a robust k centers algorithm that ignores a certain number of |
|
outlier datapoints. The resulting centers are the solution to a mixed integer program.
|
""" |
|
|
|
|
|
import numpy as np |
|
from sklearn.metrics import pairwise_distances |
|
import torch |
|
from tqdm import tqdm |
|
import abc |
|
import argparse |
|
|
|
class SamplingMethod(object):
    """Base class for active-learning sampling strategies.

    Subclasses implement ``select_batch_``; callers use ``select_batch``,
    which forwards its keyword arguments to ``select_batch_``.
    """

    # NOTE: ``__metaclass__`` is Python 2 syntax and is ignored by Python 3,
    # so the ``@abc.abstractmethod`` markers below are not enforced. It is
    # kept as-is because switching to ``abc.ABC`` would make direct
    # instantiation (and ``SamplingMethod.__new__``) raise TypeError.
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def __init__(self, X, y, seed, **kwargs):
        """Store the pool data ``X``, labels ``y`` and random ``seed``."""
        self.X = X
        self.y = y
        self.seed = seed

    def flatten_X(self):
        """Return ``self.X`` reshaped to 2-D ``(shape[0], prod(shape[1:]))``.

        Inputs with at most two axes are returned unchanged (the same
        object); higher-rank inputs have all trailing axes collapsed.
        """
        shape = self.X.shape
        flat_X = self.X
        if len(shape) > 2:
            # np.product was deprecated in NumPy 1.25 and removed in 2.0;
            # np.prod is the supported spelling.
            flat_X = np.reshape(self.X, (shape[0], int(np.prod(shape[1:]))))
        return flat_X

    @abc.abstractmethod
    def select_batch_(self):
        """Strategy-specific batch selection; overridden by subclasses."""
        return

    def select_batch(self, **kwargs):
        """Public entry point: forward all keyword arguments to ``select_batch_``."""
        return self.select_batch_(**kwargs)

    def to_dict(self):
        """Return a serializable description of the sampler (none by default)."""
        return None
|
|
|
class kCenterGreedy(SamplingMethod):
    """Greedy k-Center sampler (Sener & Savarese, 2017).

    Repeatedly picks the pool point whose distance to its nearest chosen
    center is largest, greedily minimizing the maximum point-to-center
    distance.
    """

    def __init__(self, X, y, seed, metric="euclidean"):
        """Build the sampler over the pool ``X``.

        Args:
          X: pool of datapoints, indexed along axis 0. Inputs with more than
            two axes are flattened to 2-D by ``flatten_X``.
          y: labels (stored only; not used by this sampler).
          seed: seed for the random choice of the first center when no
            center exists yet.
          metric: metric name forwarded to
            ``sklearn.metrics.pairwise_distances``.
        """
        self.X = X
        self.y = y
        self.seed = seed
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        self.features = self.flat_X
        self.metric = metric
        # Column vector (n_obs, 1): distance from each point to its nearest
        # center; None until the first center has been added.
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        self.already_selected = []

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update min distances given cluster centers.

        Args:
          cluster_centers: indices of cluster centers
          only_new: only calculate distance for centers not already in
            ``self.already_selected`` and update min_distances.
          reset_dist: whether to reset min_distances before updating.
        """
        if reset_dist:
            self.min_distances = None
        if only_new:
            known = set(self.already_selected)
            cluster_centers = [d for d in cluster_centers if d not in known]
        if len(cluster_centers) == 0:
            return

        x = self.features[cluster_centers]
        dist = pairwise_distances(self.features, x, metric=self.metric)
        # Reduce over the new centers first so min_distances keeps shape
        # (n_obs, 1) even when several centers are added in one call.
        new_min = np.min(dist, axis=1).reshape(-1, 1)
        if self.min_distances is None:
            self.min_distances = new_min
        else:
            self.min_distances = np.minimum(self.min_distances, new_min)

    def select_batch_(self, model, already_selected, N, **kwargs):
        """
        Diversity promoting active learning method that greedily forms a batch
        to minimize the maximum distance to a cluster center among all unlabeled
        datapoints.

        Args:
          model: object exposing ``transform(X)``; if it is None or the call
            fails, the current ``self.features`` (initially flat_X) are used.
          already_selected: index of datapoints already selected
          N: batch size

        Returns:
          indices of points selected to minimize distance to cluster centers
        """
        try:
            print("Getting transformed features...")
            features = model.transform(self.X)
        except Exception:
            # Deliberate best-effort: no usable model (e.g. model=None).
            print("Using flat_X as features.")
            self.update_distances(already_selected, only_new=True, reset_dist=False)
        else:
            self.features = features
            print("Calculating distances...")
            self.update_distances(already_selected, only_new=False, reset_dist=True)

        new_batch = []
        for _ in tqdm(range(N)):
            if self.min_distances is None:
                # No center exists yet: start from a seeded random point that
                # is not already selected.
                rng = np.random.RandomState(self.seed)
                candidates = np.setdiff1d(np.arange(self.n_obs), already_selected)
                ind = rng.choice(candidates)
            else:
                ind = np.argmax(self.min_distances)
            # Already-selected points are centers (distance 0), so argmax
            # should never return one. NOTE: if N exceeds the number of
            # unselected points, all distances reach 0 and argmax returns
            # index 0 repeatedly, producing duplicates in new_batch.
            assert ind not in already_selected

            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)

        if self.min_distances is not None:
            print(
                "Maximum distance from cluster centers is %0.2f"
                % float(np.max(self.min_distances))
            )

        self.already_selected = already_selected
        return new_batch
|
if __name__ == '__main__':
    # Select a diverse subset of a JSON-lines file with k-Center-Greedy over
    # precomputed embeddings. Assumes row i of the embeddings corresponds to
    # line i of the data file — TODO confirm with the embedding script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--start', type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument(
        "--embeddings",
        default="/home/aiscuser/fhw/embeddings/qwq_ins_embeddings.pt",
    )
    parser.add_argument(
        "--data",
        default="/home/aiscuser/fhw/data/qwq_python_selected.json",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="output path (default derived from --start/--end)",
    )
    parser.add_argument("--num", type=int, default=10000)
    args = parser.parse_args()

    output_path = args.output or (
        f"/home/aiscuser/fhw/data/qwq_python_diverse_{args.start}_{args.end}.json"
    )

    embeddings = torch.load(args.embeddings)
    with open(args.data, "r", encoding="utf-8") as f:
        lines = f.readlines()[args.start:args.end]
    features = embeddings[args.start:args.end]
    if len(lines) != len(features):
        raise ValueError(
            f"{len(lines)} data lines but {len(features)} embedding rows "
            f"for slice [{args.start}:{args.end}]"
        )

    # Never request more centers than there are points: once every point is a
    # center, argmax keeps returning index 0 and the output gains duplicates.
    selected_nums = min(args.num, len(lines))
    kcg = kCenterGreedy(X=features, y=None, seed=42)
    batch = kcg.select_batch_(model=None, already_selected=[], N=selected_nums)

    with open(output_path, "w", encoding="utf-8") as fw:
        fw.writelines(lines[idx] for idx in batch)
|
|