Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
| from typing import List | |
class BaseConditioner(nn.Module):
    """Base class producing a ``(condition, uncondition)`` tensor pair from ``y``.

    Subclasses implement ``_impl_condition`` and ``_impl_uncondition``. Calling an
    instance runs both hooks with the same ``y`` and ``metadata``. Any result whose
    dtype is float64, float32 or float16 is cast to bfloat16. Tensors of any other
    dtype (e.g. integer tensors, or ones already in bfloat16) are returned unchanged.
    """

    # Floating dtypes that __call__ down-casts to bfloat16.
    _CAST_TO_BF16 = (torch.float64, torch.float32, torch.float16)

    def __init__(self):
        super().__init__()

    def _impl_condition(self, y, metadata) -> torch.Tensor:
        """Return the conditional tensor for ``y``; must be overridden."""
        raise NotImplementedError()

    def _impl_uncondition(self, y, metadata) -> torch.Tensor:
        """Return the unconditional tensor for ``y``; must be overridden."""
        raise NotImplementedError()

    @classmethod
    def _to_bf16(cls, tensor: torch.Tensor) -> torch.Tensor:
        """Cast ``tensor`` to bfloat16 if its dtype is in ``_CAST_TO_BF16``."""
        if tensor.dtype in cls._CAST_TO_BF16:
            return tensor.to(torch.bfloat16)
        return tensor

    def __call__(self, y, metadata=None):
        """Compute and return ``(condition, uncondition)`` for ``y``.

        Args:
            y: Input handed unchanged to both ``_impl_*`` hooks.
            metadata: Optional dict handed to both hooks. Defaults to a fresh
                empty dict on every call.

        Returns:
            Tuple ``(condition, uncondition)``; float64/32/16 results are cast to bfloat16.

        Raises:
            NotImplementedError: if a subclass does not override the hooks.
        """
        # A fresh dict per call: the previous ``{}`` default was a single shared
        # object, so anything a hook stored in it leaked into later calls.
        if metadata is None:
            metadata = {}
        # This overrides nn.Module.__call__ directly, so registered forward hooks
        # are not run and ``forward`` is never used.
        condition = self._impl_condition(y, metadata)
        uncondition = self._impl_uncondition(y, metadata)
        return self._to_bf16(condition), self._to_bf16(uncondition)
class ComposeConditioner(BaseConditioner):
    """Conditioner that concatenates the outputs of several child conditioners.

    Each child's raw ``_impl_condition`` / ``_impl_uncondition`` is called (not
    the child's ``__call__``). The results are concatenated along dim 1 in list
    order. The dtype cast happens once, on the concatenated result, when this
    object itself is called.
    """

    def __init__(self, conditioners: List[BaseConditioner]):
        super().__init__()
        # NOTE(review): this is a plain Python list, not nn.ModuleList, so the
        # children are not registered as submodules. Their parameters will not
        # appear in this module's .parameters()/state_dict() or follow .to(device).
        # Confirm this is intended, e.g. for frozen encoders kept out of checkpoints.
        self.conditioners = conditioners

    def _impl_condition(self, y, metadata):
        """Concatenate every child's conditional tensor along dim 1."""
        pieces = [child._impl_condition(y, metadata) for child in self.conditioners]
        return torch.cat(pieces, dim=1)

    def _impl_uncondition(self, y, metadata):
        """Concatenate every child's unconditional tensor along dim 1."""
        pieces = [child._impl_uncondition(y, metadata) for child in self.conditioners]
        return torch.cat(pieces, dim=1)