DiCoW_v3_2 / FDDT.py
Lakoc's picture
Upload DiCoWForConditionalGeneration
c00ff2c verified
from typing import Optional
import torch
from torch import nn
from .layers import CustomDiagonalLinear, CustomLinear
class FDDT(nn.Module):
    """Apply per-class transformations to hidden states, weighted by a mask.

    The mask passed to ``forward`` is indexed along dimension 1 with the
    channel order ``0 = silence``, ``1 = target``, ``2 = non-target``,
    ``3 = overlap``. Each enabled channel owns either a learnable bias vector
    (``bias_only=True``) or a linear layer (``CustomDiagonalLinear`` when
    ``is_diagonal`` is set, ``CustomLinear`` otherwise).

    Args:
        d_model: Size of the feature dimension of the hidden states.
        non_target_rate: Value forwarded as ``init_eye_val`` to the
            non-target and silence layers (the target and overlap layers
            receive ``1.0``). Presumably scales an identity-like
            initialisation -- confirm against ``.layers``.
        is_diagonal: Use ``CustomDiagonalLinear`` instead of ``CustomLinear``.
        bias_only: Use a zero-initialised bias vector per channel instead of
            a linear layer.
        use_silence: Create and apply the silence branch.
        use_target: Create and apply the target branch.
        use_overlap: Create and apply the overlap branch.
        use_non_target: Create and apply the non-target branch.
        use_interaction: Apply ``scb_module`` to the result of ``forward``.
        scb_module: Module stored as ``self.scb``; only stored when
            ``use_interaction`` is true.
    """

    def __init__(self, d_model: int, non_target_rate: float = 0.01, is_diagonal: bool = False,
                 bias_only: bool = False, use_silence: bool = True,
                 use_target: bool = True, use_overlap: bool = True, use_non_target: bool = True,
                 use_interaction: bool = False,
                 scb_module: Optional[nn.Module] = None, ):
        super().__init__()
        # Registration order (target, non_target, overlap, silence, scb) is
        # kept as-is so parameter/state-dict ordering does not change.
        if use_target:
            self.target_linear = self._build_branch(d_model, 1.0, is_diagonal, bias_only)
        if use_non_target:
            self.non_target_linear = self._build_branch(d_model, non_target_rate, is_diagonal, bias_only)
        if use_overlap:
            self.overlap_linear = self._build_branch(d_model, 1.0, is_diagonal, bias_only)
        if use_silence:
            self.silence_linear = self._build_branch(d_model, non_target_rate, is_diagonal, bias_only)
        if use_interaction:
            self.scb = scb_module
        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def _build_branch(d_model, init_eye_val, is_diagonal, bias_only):
        """Create the learnable object for one mask channel."""
        if bias_only:
            # Zero bias: the branch initially leaves the hidden states unchanged.
            return nn.Parameter(torch.zeros(d_model))
        if is_diagonal:
            return CustomDiagonalLinear(d_model, bias=True, init_eye_val=init_eye_val)
        return CustomLinear(d_model, d_model, bias=True, init_eye_val=init_eye_val)

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        """Zero out ``hidden_states`` wherever ``mask`` rounds to 0.

        ``mask`` is rounded and converted to bool before the multiplication,
        so it must be broadcastable against ``hidden_states``.
        """
        mask = torch.round(mask).bool()
        masked_hidden_states = hidden_states * mask
        return masked_hidden_states

    def forward(self, hidden_states: torch.Tensor, stno_mask: torch.Tensor) -> torch.Tensor:
        """Combine the per-class branches according to ``stno_mask``.

        Args:
            hidden_states: Input features. Assumed to be
                ``(batch, time, d_model)`` -- TODO confirm against caller.
            stno_mask: Per-class weights with at least four channels on
                dimension 1 (silence, target, non-target, overlap). Assumed
                to be ``(batch, 4, time)`` -- TODO confirm. It is moved to
                the device of ``hidden_states`` and gets a trailing
                singleton dimension so it broadcasts over features.

        Returns:
            The transformed hidden states. The input tensor is never modified
            in place.
        """
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            # Accumulate on a copy: writing into the caller's tensor would
            # silently change it and fails outright on leaf tensors that
            # require grad. Cloning (rather than out-of-place `+`) keeps the
            # dtype of `hidden_states` even if the mask has a wider dtype.
            hidden_states = hidden_states.clone()
            if self.use_silence:
                hidden_states += stno_mask[:, 0, ...] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1, ...] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
        else:
            orig_hidden_states = hidden_states
            # A disabled branch passes the input through unchanged, still
            # weighted by its mask channel.
            silence = self.silence_linear(orig_hidden_states) if self.use_silence else orig_hidden_states
            target = self.target_linear(orig_hidden_states) if self.use_target else orig_hidden_states
            non_target = (self.non_target_linear(orig_hidden_states)
                          if self.use_non_target else orig_hidden_states)
            overlap = self.overlap_linear(orig_hidden_states) if self.use_overlap else orig_hidden_states
            hidden_states = (silence * stno_mask[:, 0, :]
                             + target * stno_mask[:, 1, :]
                             + non_target * stno_mask[:, 2, :]
                             + overlap * stno_mask[:, 3, :])
        # NOTE: only mask channels 0-3 are consumed; the interaction module is
        # applied to the combined result rather than gated by a fifth channel.
        if self.use_interaction:
            hidden_states = self.scb(hidden_states)
        return hidden_states