Spaces:
Running
Running
File size: 2,483 Bytes
9791162 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import abc
import logging
from typing import cast, Callable
from sklearn.cluster import MiniBatchKMeans
from feature_retrieval.index import NumpyArray
logger = logging.getLogger(__name__)
class IFeatureMatrixTransform(abc.ABC):
    """Interface for transforming an encoded voice feature matrix.

    Implementations map a matrix of shape (n_features, vector_dim) to a
    matrix of shape (m_features, vector_dim).

    The class derives from ``abc.ABC`` so that ``@abc.abstractmethod`` is
    actually enforced: without the ABC metaclass the decorator is inert and
    an incomplete implementation would only fail when ``transform`` is called.
    """

    @abc.abstractmethod
    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Transform a feature matrix from (n_features, vector_dim) to (m_features, vector_dim).

        Args:
            matrix: the feature matrix to transform.

        Returns:
            The transformed feature matrix.

        Raises:
            NotImplementedError: when a subclass delegates to this base
                implementation instead of providing its own.
        """
        raise NotImplementedError
class DummyFeatureTransform(IFeatureMatrixTransform):
    """No-op transform: hands the feature matrix back untouched."""

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Return ``matrix`` itself (the same object, not a copy)."""
        return matrix
class MinibatchKmeansFeatureTransform(IFeatureMatrixTransform):
    """Replace the examples of a feature matrix with k-means centroids.

    The centroids are fitted with the mini-batch k-means algorithm.
    """

    def __init__(self, n_clusters: int, n_parallel: int) -> None:
        """Store the clustering settings.

        Args:
            n_clusters: number of centroids to fit.
            n_parallel: multiplier used to derive the mini-batch size.
        """
        self._n_clusters = n_clusters
        self._n_parallel = n_parallel

    @property
    def _batch_size(self) -> int:
        """Mini-batch size: 256 examples per unit of ``n_parallel``."""
        return self._n_parallel * 256

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Transform a feature matrix from (n_features, vector_dim) to (n_clusters, vector_dim).

        Args:
            matrix: the feature matrix to cluster.

        Returns:
            The fitted cluster centers.
        """
        kmeans = MiniBatchKMeans(
            n_clusters=self._n_clusters,
            init="k-means++",
            batch_size=self._batch_size,
            # Only the centroids are used, so per-sample labels are not computed.
            compute_labels=False,
            verbose=True,
        )
        kmeans.fit(matrix)
        centroids = kmeans.cluster_centers_
        return cast(NumpyArray, centroids)
class OnConditionFeatureTransform(IFeatureMatrixTransform):
    """Dispatch to one of two transforms depending on a predicate.

    When the condition holds for the matrix, the ``on_condition`` transform
    is applied; in every other case the ``otherwise`` transform is applied.
    """

    def __init__(
        self,
        condition: Callable[[NumpyArray], bool],
        on_condition: IFeatureMatrixTransform,
        otherwise: IFeatureMatrixTransform,
    ) -> None:
        """Store the predicate and both candidate transforms.

        Args:
            condition: predicate evaluated on the matrix passed to ``transform``.
            on_condition: transform applied when the predicate is truthy.
            otherwise: transform applied when the predicate is falsy.
        """
        self._otherwise = otherwise
        self._on_condition = on_condition
        self._condition = condition

    def transform(self, matrix: NumpyArray) -> NumpyArray:
        """Apply the transform selected by the condition and log which one ran.

        Args:
            matrix: the feature matrix to transform.

        Returns:
            The result of the selected transform.
        """
        if self._condition(matrix):
            selected = self._on_condition
            transform_name = selected.__class__.__name__
            logger.info(f"pass condition. Transform by rule {transform_name}")
        else:
            selected = self._otherwise
            transform_name = selected.__class__.__name__
            logger.info(f"condition is not passed. Transform by rule {transform_name}")
        # The log record is emitted before the selected transform runs.
        return selected.transform(matrix)
|