atsushieee's picture
Upload folder using huggingface_hub
9791162
raw
history blame
1.46 kB
import abc
import logging
import torch
from feature_retrieval import FaissRetrievableFeatureIndex
# Per-module logger; the retrieval classes emit debug messages through it.
logger: logging.Logger = logging.getLogger(name=__name__)
class IRetrieval(abc.ABC):
    """Abstract interface for feature retrieval.

    A retriever maps an input feature tensor to a retrieved feature tensor.
    It handles two separate feature streams ("whisper" and "hubert"), one
    per method. Concrete subclasses must override both methods; the class
    cannot be instantiated otherwise.
    """

    @abc.abstractmethod
    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for a whisper-stream feature tensor.

        Args:
            vec: Input features. Shape/dtype are defined by the concrete
                implementation (not fixed here).

        Raises:
            NotImplementedError: If called on the abstract base directly.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for a hubert-stream feature tensor.

        Args:
            vec: Input features. Shape/dtype are defined by the concrete
                implementation (not fixed here).

        Raises:
            NotImplementedError: If called on the abstract base directly.
        """
        raise NotImplementedError
class DummyRetrieval(IRetrieval):
    """Pass-through retriever that performs no lookup.

    Both methods return a new tensor with the same values as the input,
    placed on the CPU. The result never aliases the caller's tensor, even
    when the input is already on the CPU.
    """

    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Return a CPU copy of the whisper-stream features ``vec``."""
        logger.debug("start dummy retriv whisper")
        return self._cpu_copy(vec)

    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Return a CPU copy of the hubert-stream features ``vec``."""
        logger.debug("start dummy retriv hubert")
        return self._cpu_copy(vec)

    @staticmethod
    def _cpu_copy(vec: torch.Tensor) -> torch.Tensor:
        """Return a fresh CPU copy of ``vec`` (autograd-tracked, format-preserving)."""
        # copy=True guarantees a new tensor even when vec is already on CPU,
        # which is the same guarantee the former clone() gave. Unlike
        # clone().to(cpu), a non-CPU input is copied once, straight to the
        # CPU, without first allocating a full clone on the source device.
        return vec.to(torch.device("cpu"), copy=True)
class FaissIndexRetrieval(IRetrieval):
    """Retriever that looks features up in two ``FaissRetrievableFeatureIndex`` objects.

    One index serves the hubert stream and the other the whisper stream.
    Inputs are converted to NumPy for the index query. Results are returned
    as CPU tensors built with ``torch.from_numpy``, so they share memory
    with the array the index returned.
    """

    def __init__(
        self,
        hubert_index: FaissRetrievableFeatureIndex,
        whisper_index: FaissRetrievableFeatureIndex,
    ) -> None:
        """Store the per-stream indexes.

        Args:
            hubert_index: Index queried by ``retriv_hubert``.
            whisper_index: Index queried by ``retriv_whisper``.
        """
        self._hubert_index = hubert_index
        self._whisper_index = whisper_index

    @staticmethod
    def _query_index(index: FaissRetrievableFeatureIndex, vec: torch.Tensor) -> torch.Tensor:
        """Run ``index.retriv`` on ``vec`` and wrap the NumPy result as a tensor."""
        # Tensor.numpy() raises for tensors that require grad or that live on
        # a non-CPU device. Detach and move to CPU first. Both calls are no-ops
        # for a CPU tensor without grad, so that case still avoids a copy.
        np_vec = index.retriv(vec.detach().cpu().numpy())
        return torch.from_numpy(np_vec)

    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Retrieve whisper-stream features for ``vec`` from the whisper index."""
        logger.debug("start retriv whisper")
        return self._query_index(self._whisper_index, vec)

    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Retrieve hubert-stream features for ``vec`` from the hubert index."""
        logger.debug("start retriv hubert")
        return self._query_index(self._hubert_index, vec)