Spaces: Running on Zero
###
# Author: Kai Li
# Date: 2021-06-17 23:08:32
# LastEditors: Please set LastEditors
# LastEditTime: 2022-05-26 18:06:22
###
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
| def _unsqueeze_to_3d(x): | |
| """Normalize shape of `x` to [batch, n_chan, time].""" | |
| if x.ndim == 1: | |
| return x.reshape(1, 1, -1) | |
| elif x.ndim == 2: | |
| return x.unsqueeze(1) | |
| else: | |
| return x | |
def pad_to_appropriate_length(x, lcm):
    """Right-pad the last axis of `x` with zeros up to a multiple of `lcm`.

    Args:
        x: Tensor whose last dimension is the one to be aligned.
        lcm (int): Target block size; the output's last-dim length is the
            smallest multiple of `lcm` that is >= ``x.shape[-1]``.

    Returns:
        `x` itself when already aligned; otherwise a new zero-padded tensor
        with the same dtype and device as `x`.
    """
    remainder = int(x.shape[-1]) % lcm
    if not remainder:
        return x
    # new_zeros keeps the input's dtype/device — the previous implementation
    # hardcoded float32, silently down/up-casting e.g. float64 or half inputs.
    padded = x.new_zeros(*x.shape[:-1], x.shape[-1] + lcm - remainder)
    padded[..., : x.shape[-1]] = x
    return padded
class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
    """Base class for audio models in this project.

    Combines ``nn.Module`` with ``PyTorchModelHubMixin`` so that subclasses
    can be pushed to and loaded from the Hugging Face Hub; the class kwargs
    attach the repo URL and pipeline tag to the Hub model card.
    """

    def __init__(self, sample_rate, in_chan=1):
        # sample_rate: audio sampling rate (presumably in Hz — not validated here).
        # in_chan: number of input channels; defaults to mono.
        super().__init__()
        self._sample_rate = sample_rate
        self._in_chan = in_chan

    def forward(self, *args, **kwargs):
        """Subclasses must implement the actual computation."""
        raise NotImplementedError

    def sample_rate(self,):
        """Return the sampling rate passed at construction time.

        NOTE(review): this is a plain getter, not a ``@property`` — callers
        must invoke ``model.sample_rate()`` with parentheses.
        """
        return self._sample_rate

    def load_state_dict_in_audio(model, pretrained_dict):
        """Load only the ``audio_model``-prefixed weights into ``model``.

        NOTE(review): the first parameter is named ``model`` rather than
        ``self``; when called on an instance, that instance is ``model``.

        Args:
            pretrained_dict: a checkpoint state dict whose relevant keys
                contain the substring ``"audio_model"``.

        Returns:
            ``model`` with the matching weights loaded (modified in place).
        """
        model_dict = model.state_dict()
        update_dict = {}
        for k, v in pretrained_dict.items():
            if "audio_model" in k:
                # k[12:] strips the leading 12 characters — presumably the
                # "audio_model." prefix; assumes keys start with that exact
                # prefix rather than merely containing it. TODO confirm.
                update_dict[k[12:]] = v
        model_dict.update(update_dict)
        model.load_state_dict(model_dict)
        return model

    # @staticmethod
    # def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
    #     from . import get
    #     conf = torch.load(
    #         pretrained_model_conf_or_path, map_location="cpu"
    #     )  # Attempt to find the model and instantiate it.
    #     model_class = get(conf["model_name"])
    #     # model_class = get("Conv_TasNet")
    #     model = model_class(*args, **kwargs)
    #     model.load_state_dict(conf["state_dict"])
    #     return model

    def serialize(self):
        """Package the model into a dict suitable for ``torch.save``.

        Returns a dict with the class name, state dict, re-instantiation
        args (via :meth:`get_model_args`), and software-version metadata.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos: record the library versions used to produce
        # this checkpoint, for reproducibility when reloading later.
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__, pytorch_lightning_version=pl.__version__,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Should return args to re-instantiate the class."""
        raise NotImplementedError