from typing import Optional

import torch
from torch import nn

from .layers import CustomDiagonalLinear, CustomLinear


class FDDT(nn.Module):
    """Frame-level Diarization Dependent Transformation (FDDT).

    Transforms hidden states with class-specific affine layers, blended by a
    per-frame STNO (silence / target / non-target / overlap) soft mask.
    """

    def __init__(self, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False, use_silence=True,
                 use_target=True, use_overlap=True, use_non_target=True, use_interaction=False,
                 scb_module: Optional[nn.Module] = None):
        super().__init__()

        def make_transform(init_eye_val):
            # bias_only: a learnable additive offset, initialized to zero.
            # Otherwise a (diagonal) linear layer initialized to
            # init_eye_val * identity.
            if bias_only:
                return nn.Parameter(torch.zeros(d_model))
            if is_diagonal:
                return CustomDiagonalLinear(d_model, bias=True, init_eye_val=init_eye_val)
            return CustomLinear(d_model, d_model, bias=True, init_eye_val=init_eye_val)

        if use_target:
            self.target_linear = make_transform(init_eye_val=1.0)
        if use_non_target:
            self.non_target_linear = make_transform(init_eye_val=non_target_rate)
        if use_overlap:
            self.overlap_linear = make_transform(init_eye_val=1.0)
        if use_silence:
            self.silence_linear = make_transform(init_eye_val=non_target_rate)
        if use_interaction:
            self.scb = scb_module

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        # Hard-gate the hidden states: frames whose (soft) mask rounds to 0
        # are zeroed out.
        mask = torch.round(mask).bool()
        return hidden_states * mask

    def forward(self, hidden_states, stno_mask):
        # stno_mask: (batch, 4, time) soft class posteriors; add a trailing
        # singleton dim so it broadcasts over the feature dimension.
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            # Add class-specific biases, weighted per frame. Out-of-place
            # addition avoids mutating the caller's tensor.
            if self.use_silence:
                hidden_states = hidden_states + stno_mask[:, 0] * self.silence_linear
            if self.use_target:
                hidden_states = hidden_states + stno_mask[:, 1] * self.target_linear
            if self.use_non_target:
                hidden_states = hidden_states + stno_mask[:, 2] * self.non_target_linear
            if self.use_overlap:
                hidden_states = hidden_states + stno_mask[:, 3] * self.overlap_linear
        else:
            # Blend class-specific transformations of the original hidden
            # states; a disabled class passes them through unchanged.
            orig_hidden_states = hidden_states
            silence = self.silence_linear(orig_hidden_states) if self.use_silence else orig_hidden_states
            target = self.target_linear(orig_hidden_states) if self.use_target else orig_hidden_states
            non_target = self.non_target_linear(orig_hidden_states) if self.use_non_target else orig_hidden_states
            overlap = self.overlap_linear(orig_hidden_states) if self.use_overlap else orig_hidden_states
            hidden_states = (silence * stno_mask[:, 0]
                             + target * stno_mask[:, 1]
                             + non_target * stno_mask[:, 2]
                             + overlap * stno_mask[:, 3])
        if self.use_interaction:
            hidden_states = self.scb(hidden_states)
        return hidden_states
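

if __name__ == "__main__":
    # Minimal usage sketch / smoke test (illustrative only: the shapes and
    # the bias-only configuration are assumptions, not part of the module).
    batch, time, d_model = 2, 100, 256
    fddt = FDDT(d_model, bias_only=True)

    hidden_states = torch.randn(batch, time, d_model)
    # STNO mask: per-frame soft posteriors over
    # (silence, target, non-target, overlap), summing to 1 along dim 1.
    stno_mask = torch.randn(batch, 4, time).softmax(dim=1)

    out = fddt(hidden_states, stno_mask)
    assert out.shape == (batch, time, d_model)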