Spaces:
Running
Running
import abc | |
import logging | |
from typing import cast, Callable | |
from sklearn.cluster import MiniBatchKMeans | |
from feature_retrieval.index import NumpyArray | |
# Module-level logger, named after this module so records follow the package hierarchy.
logger = logging.getLogger(name=__name__)
class IFeatureMatrixTransform:
    """Base interface for transforms applied to an encoded voice-feature matrix.

    A transform maps a matrix of shape ``(n_features, vector_dim)`` to one of
    shape ``(m_features, vector_dim)``. Subclasses must override
    :meth:`transform`.
    """

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Map *matrix* from (n_features, vector_dim) to (m_features, vector_dim).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class DummyFeatureTransform(IFeatureMatrixTransform):
    """No-op transform: leaves the feature matrix as it is."""

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Return *matrix* itself (the same object, not a copy)."""
        unchanged = matrix
        return unchanged
class MinibatchKmeansFeatureTransform(IFeatureMatrixTransform):
    """Replace the feature rows with k-means centroids fitted by the mini-batch algorithm."""

    def __init__(self, n_clusters: int, n_parallel: int) -> None:
        # Number of centroids, i.e. the number of rows transform() returns.
        self._n_clusters = n_clusters
        # Multiplier for the mini-batch size (see _batch_size).
        self._n_parallel = n_parallel

    def _batch_size(self) -> int:
        """Mini-batch size handed to MiniBatchKMeans: 256 rows per unit of n_parallel."""
        return self._n_parallel * 256

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Transform *matrix* from (n_features, vector_dim) to (n_clusters, vector_dim).

        Fits MiniBatchKMeans (k-means++ init, labels not computed) on the rows
        of *matrix* and returns the fitted cluster centers.
        """
        cluster = MiniBatchKMeans(
            n_clusters=self._n_clusters,
            verbose=True,
            # _batch_size is a method and must be called: passing the bound
            # method itself is rejected because batch_size must be an int.
            batch_size=self._batch_size(),
            compute_labels=False,
            init="k-means++",
        )
        return cast(NumpyArray, cluster.fit(matrix).cluster_centers_)
class OnConditionFeatureTransform(IFeatureMatrixTransform):
    """Call one of two transforms depending on a predicate over the matrix.

    ``condition(matrix)`` is evaluated once per call. A truthy result sends the
    matrix to ``on_condition``; otherwise it goes to ``otherwise``.
    """

    def __init__(
        self,
        condition: Callable[[NumpyArray], bool],
        on_condition: IFeatureMatrixTransform,
        otherwise: IFeatureMatrixTransform,
    ) -> None:
        # Predicate deciding which branch handles a given matrix.
        self._condition = condition
        # Transform used when the predicate is truthy.
        self._on_condition = on_condition
        # Transform used when the predicate is falsy.
        self._otherwise = otherwise

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Apply the branch selected by ``condition(matrix)`` and return its result."""
        if self._condition(matrix):
            transform_name = self._on_condition.__class__.__name__
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info("pass condition. Transform by rule %s", transform_name)
            return self._on_condition.transform(matrix)
        transform_name = self._otherwise.__class__.__name__
        logger.info("condition is not passed. Transform by rule %s", transform_name)
        return self._otherwise.transform(matrix)