Spaces:
Running
Running
import abc | |
import logging | |
import torch | |
from feature_retrieval import FaissRetrievableFeatureIndex | |
logger = logging.getLogger(__name__) | |
class IRetrieval(abc.ABC):
    """Interface for feature-retrieval back-ends.

    Declares one lookup method per feature stream used in this module
    (``whisper`` and ``hubert``). Both methods are abstract: a concrete
    subclass must override each of them before it can be instantiated, so a
    missing implementation is reported at construction time instead of on
    the first call.
    """

    @abc.abstractmethod
    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for the whisper features ``vec``.

        Args:
            vec: Input feature tensor.

        Returns:
            A tensor produced by the concrete retrieval back-end.

        Raises:
            NotImplementedError: If reached through ``super()`` without an
                overriding implementation.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Return the retrieval result for the hubert features ``vec``.

        Args:
            vec: Input feature tensor.

        Returns:
            A tensor produced by the concrete retrieval back-end.

        Raises:
            NotImplementedError: If reached through ``super()`` without an
                overriding implementation.
        """
        raise NotImplementedError
class DummyRetrieval(IRetrieval):
    """Pass-through retrieval that performs no index lookup.

    Each method logs a debug message and hands back a copy of its input
    placed on the CPU device; the values themselves are not modified.
    """

    def retriv_whisper(self, vec: torch.FloatTensor) -> torch.FloatTensor:
        """Return a CPU-resident copy of the whisper features ``vec``."""
        logger.debug("start dummy retriv whisper")
        cpu_device = torch.device("cpu")
        # Clone first so the caller's tensor is never aliased by the result.
        duplicate = vec.clone()
        return duplicate.to(cpu_device)

    def retriv_hubert(self, vec: torch.FloatTensor) -> torch.FloatTensor:
        """Return a CPU-resident copy of the hubert features ``vec``."""
        logger.debug("start dummy retriv hubert")
        cpu_device = torch.device("cpu")
        # Clone first so the caller's tensor is never aliased by the result.
        duplicate = vec.clone()
        return duplicate.to(cpu_device)
class FaissIndexRetrieval(IRetrieval):
    """Retrieval backed by two ``FaissRetrievableFeatureIndex`` objects.

    One index serves hubert features and the other whisper features. Each
    query converts the input tensor to a NumPy array, passes it to the
    index's ``retriv`` method, and wraps the returned array in a tensor.
    """

    def __init__(self, hubert_index: FaissRetrievableFeatureIndex, whisper_index: FaissRetrievableFeatureIndex) -> None:
        """Store the indexes queried by the two retrieval methods.

        Args:
            hubert_index: Index used by :meth:`retriv_hubert`.
            whisper_index: Index used by :meth:`retriv_whisper`.
        """
        self._hubert_index = hubert_index
        self._whisper_index = whisper_index

    @staticmethod
    def _query(index: FaissRetrievableFeatureIndex, vec: torch.Tensor) -> torch.Tensor:
        """Run ``vec`` through ``index.retriv`` and return the result as a tensor."""
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a
        # non-CPU device; detach and move to CPU first. Both calls are no-ops
        # for a plain CPU tensor, so existing callers are unaffected.
        np_query = vec.detach().cpu().numpy()
        np_result = index.retriv(np_query)
        return torch.from_numpy(np_result)

    def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
        """Query the whisper index with ``vec``.

        Args:
            vec: Input feature tensor; may be on any device and may require
                grad.

        Returns:
            A CPU tensor wrapping the array returned by the whisper index.
        """
        logger.debug("start retriv whisper")
        return self._query(self._whisper_index, vec)

    def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
        """Query the hubert index with ``vec``.

        Args:
            vec: Input feature tensor; may be on any device and may require
                grad.

        Returns:
            A CPU tensor wrapping the array returned by the hubert index.
        """
        logger.debug("start retriv hubert")
        return self._query(self._hubert_index, vec)