Source: HuggingFace Space "VideoModelStudio" — docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py
| from contextlib import contextmanager | |
| from typing import List, Union | |
| import torch | |
| from diffusers.hooks import HookRegistry, ModelHook | |
# Registry key under which the channel-concatenation hook is stored on a module;
# used to register the hook and to remove exactly this hook later.
_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK"
class ControlChannelConcatenateHook(ModelHook):
    """Hook that concatenates control tensors onto selected forward() inputs.

    For each ``(input_name, input_tensor, dim)`` triple, the argument located at
    positional index (``int``) or keyword name (``str``) ``input_name`` is
    replaced with ``torch.cat([original, input_tensor], dim=dim)`` before the
    wrapped module's forward runs.
    """

    def __init__(self, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int]) -> None:
        # input_names: positional indices (int) or keyword names (str) of the
        #              forward() arguments to augment.
        # inputs:      control tensors to concatenate, one per input name.
        # dims:        concatenation dimension for each pair.
        super().__init__()  # keep any base-class hook state initialized
        self.input_names = input_names
        self.inputs = inputs
        self.dims = dims

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # `args` arrives as a tuple; convert to a list so positional (int)
        # input names can be replaced in place. The previous
        # `args[input_name] = control_tensor` raised TypeError for integer
        # indices because tuples are immutable. Callers unpack the returned
        # args with `*args`, so a list is a compatible return value.
        args = list(args)
        for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims):
            original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name]
            # NOTE(review): assumes `input_tensor` matches `original_tensor` in
            # every dimension except `dim` — torch.cat raises otherwise.
            control_tensor = torch.cat([original_tensor, input_tensor], dim=dim)
            if isinstance(input_name, int):
                args[input_name] = control_tensor
            else:
                kwargs[input_name] = control_tensor
        return args, kwargs
@contextmanager
def control_channel_concat(
    module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int]
):
    """Temporarily concatenate control tensors onto ``module``'s forward inputs.

    While the context is active, every forward() call on ``module`` has the
    argument named/indexed by each entry of ``input_names`` replaced with
    ``torch.cat([original, control], dim=dim)`` via a registered
    :class:`ControlChannelConcatenateHook`. The hook is removed on exit.

    Args:
        module: The module whose forward inputs are augmented.
        input_names: Positional indices (int) or keyword names (str) of the
            forward() arguments to concatenate onto.
        inputs: Control tensors, one per entry of ``input_names``.
        dims: Concatenation dimension for each pair.
    """
    # The function body contains `yield`, so without @contextmanager the call
    # returned a bare generator that cannot be used in a `with` statement
    # (AttributeError: __enter__); `contextmanager` was imported but unused.
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = ControlChannelConcatenateHook(input_names, inputs, dims)
    registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK)
    try:
        yield
    finally:
        # Remove the hook even if the body raises, so the module is never left
        # permanently patched with the concatenation hook.
        registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False)