atsushieee's picture
Upload folder using huggingface_hub
9791162
raw
history blame
2.48 kB
import abc
import logging
from typing import cast, Callable
from sklearn.cluster import MiniBatchKMeans
from feature_retrieval.index import NumpyArray
logger = logging.getLogger(__name__)
class IFeatureMatrixTransform:
"""Interface for transform encoded voice feature from (n_features,vector_dim) to (m_features,vector_dim)"""
@abc.abstractmethod
def transform(self, matrix: NumpyArray) -> NumpyArray:
"""transform given feature matrix from (n_features,vector_dim) to (m_features,vector_dim)"""
raise NotImplementedError
class DummyFeatureTransform(IFeatureMatrixTransform):
"""do nothing"""
def transform(self, matrix: NumpyArray) -> NumpyArray:
return matrix
class MinibatchKmeansFeatureTransform(IFeatureMatrixTransform):
"""replaces number of examples with k-means centroids using minibatch algorythm"""
def __init__(self, n_clusters: int, n_parallel: int) -> None:
self._n_clusters = n_clusters
self._n_parallel = n_parallel
@property
def _batch_size(self) -> int:
return self._n_parallel * 256
def transform(self, matrix: NumpyArray) -> NumpyArray:
"""transform given feature matrix from (n_features,vector_dim) to (n_clusters,vector_dim)"""
cluster = MiniBatchKMeans(
n_clusters=self._n_clusters,
verbose=True,
batch_size=self._batch_size,
compute_labels=False,
init="k-means++",
)
return cast(NumpyArray, cluster.fit(matrix).cluster_centers_)
class OnConditionFeatureTransform(IFeatureMatrixTransform):
"""call given transform if condition is True else call otherwise transform"""
def __init__(
self,
condition: Callable[[NumpyArray], bool],
on_condition: IFeatureMatrixTransform,
otherwise: IFeatureMatrixTransform,
) -> None:
self._condition = condition
self._on_condition = on_condition
self._otherwise = otherwise
def transform(self, matrix: NumpyArray) -> NumpyArray:
if self._condition(matrix):
transform_name = self._on_condition.__class__.__name__
logger.info(f"pass condition. Transform by rule {transform_name}")
return self._on_condition.transform(matrix)
transform_name = self._otherwise.__class__.__name__
logger.info(f"condition is not passed. Transform by rule {transform_name}")
return self._otherwise.transform(matrix)