# codescripts / kcenter.py
# f541119578's picture
# Upload folder using huggingface_hub
# fdf190d verified
"""Original code from https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py"""
"""Returns points that minimizes the maximum distance of any point to a center.
Implements the k-Center-Greedy method in
Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
Distance metric defaults to l2 distance. Features used to calculate distance
are either the raw features or, if the model has a transform method, the output
of model.transform(X).
Can be extended to a robust k-centers algorithm that ignores a certain number of
outlier datapoints. The resulting centers are the solution to a mixed integer program.
"""
import numpy as np
from sklearn.metrics import pairwise_distances
import torch
from tqdm import tqdm
import abc
import argparse
class SamplingMethod(object):
    """Base class for batch sampling strategies over a dataset ``X``.

    Subclasses override ``select_batch_``; callers go through
    ``select_batch``, which forwards its keyword arguments unchanged.
    """

    # NOTE: ``__metaclass__`` is the Python 2 spelling and has no effect on
    # Python 3, so the ``abstractmethod`` markers below are not enforced at
    # instantiation time. It is kept as-is so that code which instantiates
    # this class (or a partial subclass) keeps working.
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def __init__(self, X, y, seed, **kwargs):
        """Store the data, the labels and the random seed.

        Args:
          X: data; must expose ``.shape`` for ``flatten_X`` to work.
          y: labels, stored untouched.
          seed: random seed, stored untouched.
          **kwargs: accepted and ignored.
        """
        self.X = X
        self.y = y
        self.seed = seed

    def flatten_X(self):
        """Return ``self.X`` as a 2-D ``(n_samples, n_features)`` array.

        Inputs with at most two dimensions are returned as-is (same object);
        higher-rank inputs have all trailing axes collapsed into one.
        """
        shape = self.X.shape
        if len(shape) <= 2:
            return self.X
        # np.prod instead of np.product: the latter was deprecated in
        # NumPy 1.25 and removed in NumPy 2.0.
        return np.reshape(self.X, (shape[0], np.prod(shape[1:])))

    @abc.abstractmethod
    def select_batch_(self):
        """Strategy-specific selection hook; the base version returns None."""
        return

    def select_batch(self, **kwargs):
        """Public entry point: delegate to ``select_batch_`` with ``kwargs``."""
        return self.select_batch_(**kwargs)

    def to_dict(self):
        """Return a serialisable description of the sampler (None here)."""
        return None
class kCenterGreedy(SamplingMethod):
    """Greedy k-center selection.

    Each step picks the point whose distance to its nearest already-chosen
    center is largest, then registers that point as a new center.

    Attributes:
      X, y: data and labels as passed in (``y`` is stored but not used here).
      seed: seed of the generator used to draw the very first center.
      flat_X: ``X`` flattened to two dimensions by ``flatten_X``.
      features: array the distances are computed on; starts as ``flat_X`` and
        is replaced by ``model.transform(X)`` when that call succeeds.
      metric: metric name forwarded to ``pairwise_distances``.
      min_distances: ``(n_obs, 1)`` array holding each point's distance to its
        nearest registered center, or ``None`` while no center is registered.
      n_obs: number of rows in ``X``.
      already_selected: the ``already_selected`` argument of the most recent
        ``select_batch_`` call (``[]`` before the first call).
    """

    def __init__(self, X, y, seed, metric="euclidean"):
        self.X = X
        self.y = y
        # The seed used to be dropped silently; keep it and use it for the
        # random choice of the first center so runs are reproducible.
        self.seed = seed
        self._rng = np.random.RandomState(seed)
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        self.features = self.flat_X
        self.metric = metric
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        self.already_selected = []

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update min distances given cluster centers.

        Args:
          cluster_centers: indices of cluster centers
          only_new: only calculate distance for newly selected points and update
            min_distances.
          reset_dist: whether to reset min_distances.
        """
        if reset_dist:
            self.min_distances = None
        # Materialise once so array-like inputs can be truth-tested below.
        centers = list(cluster_centers)
        if only_new:
            centers = [d for d in centers if d not in self.already_selected]
        if not centers:
            return

        # Distance from every example to each of the new centers: (n_obs, k).
        x = self.features[centers]
        dist = pairwise_distances(self.features, x, metric=self.metric)
        # Collapse to the nearest new center BEFORE merging. Merging the raw
        # (n_obs, k) matrix would broadcast min_distances to (n_obs, k) and
        # break every later argmax over it.
        nearest = np.min(dist, axis=1).reshape(-1, 1)
        if self.min_distances is None:
            self.min_distances = nearest
        else:
            self.min_distances = np.minimum(self.min_distances, nearest)

    def select_batch_(self, model, already_selected, N, **kwargs):
        """
        Diversity promoting active learning method that greedily forms a batch
        to minimize the maximum distance to a cluster center among all unlabeled
        datapoints.

        Args:
          model: object whose ``transform(X)`` supplies the features; when it is
            None, lacks ``transform`` or ``transform`` fails, the current
            ``self.features`` are used instead
          already_selected: index of datapoints already selected
          N: batch size
          **kwargs: accepted for interface compatibility and ignored

        Returns:
          list of int indices of points selected to minimize distance to
          cluster centers
        """
        # Assumes that the transform function takes in original data and not
        # flattened data.
        print("Getting transformed features...")
        try:
            transformed = model.transform(self.X)
        except Exception:
            # Best-effort fallback. Only the transform call is guarded, so
            # errors raised while computing distances are no longer hidden.
            use_transformed = False
        else:
            use_transformed = True

        if use_transformed:
            self.features = transformed
            print("Calculating distances...")
            self.update_distances(already_selected, only_new=False, reset_dist=True)
        else:
            print("Using flat_X as features.")
            self.update_distances(already_selected, only_new=True, reset_dist=False)

        new_batch = []
        for _ in tqdm(range(N)):
            if self.min_distances is None:
                # No center registered yet: initialize with a randomly
                # selected datapoint (seeded through self._rng).
                ind = int(self._rng.randint(self.n_obs))
            else:
                ind = int(np.argmax(self.min_distances))
            # New examples should not be in already selected since those points
            # should have min_distance of zero to a cluster center.
            assert ind not in already_selected

            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)

        # min_distances is still None when N == 0 and nothing was selected.
        if self.min_distances is not None:
            print(
                "Maximum distance from cluster centers is %0.2f"
                % np.max(self.min_distances)
            )

        self.already_selected = already_selected
        return new_batch
if __name__ == '__main__':
    # Select a diverse subset of the rows [start:end) with k-Center-Greedy and
    # write the matching input lines to a per-range output file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--start', type=int)
    parser.add_argument("--end", type=int)
    args = parser.parse_args()

    # NOTE(review): torch.load unpickles its input; only point it at trusted files.
    embeddings = torch.load("/home/aiscuser/fhw/embeddings/qwq_ins_embeddings.pt")

    # Presumably line i of this file corresponds to row i of the embeddings;
    # verify against the code that produced them.
    with open("/home/aiscuser/fhw/data/qwq_python_selected.json", "r") as f:
        lines = f.readlines()[args.start:args.end]

    subset = embeddings[args.start:args.end]
    # Never ask for more points than the slice holds: once every point is a
    # center all distances are zero and argmax would keep returning index 0.
    selected_nums = min(10000, subset.shape[0])

    kcg = kCenterGreedy(X=subset, y=None, seed=42)
    batch = kcg.select_batch_(model=None, already_selected=[], N=selected_nums)

    # Open the output only after selection succeeded, so a failed run does not
    # leave a truncated file behind; the context manager closes the handle.
    with open(f"/home/aiscuser/fhw/data/qwq_python_diverse_{args.start}_{args.end}.json", "w") as fw:
        fw.writelines(lines[idx] for idx in batch)