|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
def _unsqueeze_to_3d(x):
    """Bring `x` to a [batch, n_chan, time] layout when it has fewer than 3 dims.

    A 1-D input is reshaped to ``(1, 1, len)``, a 2-D input gets a singleton
    axis inserted at position 1, and anything else is handed back untouched.
    """
    rank = x.ndim
    if rank == 1:
        # Lone signal: add both the batch and the channel axes.
        return x.reshape(1, 1, -1)
    if rank == 2:
        # Two axes present: insert a singleton axis in the middle.
        return x.unsqueeze(1)
    # Every other rank is returned as-is.
    return x
|
|
|
|
|
|
|
|
def pad_to_appropriate_length(x, lcm):
    """Zero-pad the last dimension of `x` up to the next multiple of `lcm`.

    Args:
        x: Tensor whose last dimension is to be padded.
        lcm: Integer the last-dimension size must become a multiple of.

    Returns:
        `x` itself (no copy) when its last dimension is already a multiple of
        `lcm`; otherwise a new tensor with the same leading shape, dtype and
        device as `x`, containing `x` followed by trailing zeros.

    Raises:
        ZeroDivisionError: If `lcm` is 0 (from the modulo operation).
    """
    length = int(x.shape[-1])
    remainder = length % lcm
    if not remainder:
        return x

    # `new_zeros` allocates directly on x's device with x's dtype. The padded
    # result therefore has the same dtype as the unpadded early-return path
    # instead of being forced to float32, and no CPU->device copy is needed.
    padded_shape = list(x.shape[:-1]) + [length + lcm - remainder]
    padded_x = x.new_zeros(padded_shape)
    padded_x[..., :length] = x
    return padded_x
|
|
|
|
|
|
|
|
class BaseModel(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/JusperLee/Apollo",
    pipeline_tag="audio-to-audio",
):
    """Common base for models: stores basic config and offers (de)serialization helpers.

    `forward` and `get_model_args` raise ``NotImplementedError`` here, so
    concrete models are expected to override them.

    Args:
        sample_rate: Value stored and later returned by :meth:`sample_rate`.
        in_chan: Stored as ``_in_chan`` (presumably the number of input
            channels -- confirm against subclasses). Defaults to 1.
    """

    def __init__(self, sample_rate, in_chan=1):
        super().__init__()
        self._sample_rate = sample_rate
        self._in_chan = in_chan

    def forward(self, *args, **kwargs):
        """Placeholder forward pass; must be overridden by subclasses."""
        raise NotImplementedError

    def sample_rate(self):
        """Return the sample rate given at construction time."""
        return self._sample_rate

    @staticmethod
    def load_state_dict_in_audio(model, pretrained_dict):
        """Load the ``audio_model.``-prefixed entries of `pretrained_dict` into `model`.

        Args:
            model: Module exposing ``state_dict()`` / ``load_state_dict()``.
            pretrained_dict: Mapping of parameter names to values. Only keys
                starting with ``"audio_model."`` are used; the prefix is
                removed before merging. All other keys are ignored.

        Returns:
            The same `model`, after loading the merged state dict.
        """
        prefix = "audio_model."
        model_dict = model.state_dict()
        # Match on the exact prefix that is stripped. A bare substring test
        # combined with a fixed 12-character slice would mangle keys in which
        # "audio_model" appears anywhere other than at the very start.
        update_dict = {
            key[len(prefix):]: value
            for key, value in pretrained_dict.items()
            if key.startswith(prefix)
        }
        model_dict.update(update_dict)
        model.load_state_dict(model_dict)
        return model

    def serialize(self):
        """Pack the model into a plain dictionary.

        Returns:
            dict with keys ``model_name`` (class name), ``state_dict``
            (from :meth:`get_state_dict`), ``model_args`` (from
            :meth:`get_model_args`) and ``infos``, whose
            ``software_versions`` entry records the torch and
            pytorch_lightning versions.

        Raises:
            NotImplementedError: If the subclass did not override
                :meth:`get_model_args`.
        """
        # Imported lazily so pytorch_lightning is only needed when serializing.
        import pytorch_lightning as pl

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )

        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Should return args to re-instantiate the class."""
        raise NotImplementedError