from dataclasses import dataclass
from typing import Optional

import torch
from transformers import WhisperConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput


@dataclass
class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
    """`Seq2SeqLMOutput` extended with separate encoder- and decoder-side loss
    terms and the raw encoder logits."""
    enc_loss: Optional[torch.FloatTensor] = None
    dec_loss: Optional[torch.FloatTensor] = None
    encoder_logits: Optional[torch.FloatTensor] = None


@dataclass
class BaseModelOutputLogit(BaseModelOutput):
    """`BaseModelOutput` extended with a `logits` field (e.g. for an encoder
    prediction head)."""
    logits: Optional[torch.FloatTensor] = None


@dataclass
class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
    """`Seq2SeqModelOutput` extended with the encoder logits."""
    encoder_logits: Optional[torch.FloatTensor] = None

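# A minimal illustrative sketch (dummy values, hypothetical shapes) of how the
# extended output type carries per-branch losses and encoder logits alongside
# the standard `Seq2SeqLMOutput` fields:
#
#   out = Seq2SeqLMOutputLosses(
#       loss=torch.tensor(1.2),                      # combined loss
#       enc_loss=torch.tensor(0.4),                  # encoder-side loss term
#       dec_loss=torch.tensor(0.8),                  # decoder-side loss term
#       logits=torch.zeros(1, 10, 51865),            # decoder logits
#       encoder_logits=torch.zeros(1, 1500, 51865),  # encoder logits
#   )
#   total = out.loss  # downstream code can also inspect out.enc_loss / out.dec_loss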

class DiCoWConfig(WhisperConfig):
    """This is a modified version of the `WhisperEncoder` model from the `transformers` library.
    The model has been modified to support CTC loss computation in the forward pass."""
    model_type = "DiCoW"

    def __init__(
            self,
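            # CTC auxiliary head and loss options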
            ctc_loss_reduction: str = "mean",
            final_dropout: float = 0.0,
            ctc_zero_infinity: bool = False,
            ctc_weight: float = 0.0,
            blank_token_id: Optional[int] = None,
            additional_layer: bool = False,
            additional_self_attention_layer: bool = False,
            sub_sample: bool = False,
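            # FDDT options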
            use_fddt: bool = True,
            fddt_is_diagonal: bool = True,
            fddt_bias_only: bool = False,
            fddt_use_silence: bool = True,
            fddt_use_target: bool = True,
            fddt_use_overlap: bool = True,
            fddt_use_non_target: bool = True,
            remove_timestamps_from_ctc: bool = False,
            apply_fddt_to_n_layers: int = -1,
            fddt_init: str = 'non-disturbing',  # one of: 'random', 'non-disturbing', 'dispargement'
            n_soft_prompts: int = 16,
            mt_num_speakers: int = 1,
            is_mt: bool = False,
            non_target_fddt_value: float = 0.0,
            use_initial_fddt: bool = False,
            scb_method: Optional[str] = None,
            scb_layers: int = -1,
            contrastive_loss_weight: float = 0.0,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_loss_reduction = ctc_loss_reduction
        self.final_dropout = final_dropout
        self.ctc_zero_infinity = ctc_zero_infinity
        self.ctc_weight = ctc_weight
        self.blank_token_id = blank_token_id
        self.additional_layer = additional_layer
        self.additional_self_attention_layer = additional_self_attention_layer
        self.sub_sample = sub_sample
        self.use_fddt = use_fddt
        self.fddt_is_diagonal = fddt_is_diagonal
        self.fddt_bias_only = fddt_bias_only
        self.fddt_use_silence = fddt_use_silence
        self.fddt_use_target = fddt_use_target
        self.fddt_use_overlap = fddt_use_overlap
        self.fddt_use_non_target = fddt_use_non_target
        self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
        self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
        self.fddt_init = fddt_init
        self.n_soft_prompts = n_soft_prompts
        self.mt_num_speakers = mt_num_speakers
        self.non_target_fddt_value = non_target_fddt_value
        self.use_initial_fddt = use_initial_fddt
        self.scb_method = scb_method
        self.scb_layers = scb_layers
        self.contrastive_loss_weight = contrastive_loss_weight
        self.is_mt = is_mt


# Index at which hidden states start in tuple-style (return_dict=False) outputs.
_HIDDEN_STATES_START_POSITION = 2
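

# Illustrative usage (a sketch; the values below are hypothetical examples,
# not recommended defaults):
if __name__ == "__main__":
    cfg = DiCoWConfig(
        ctc_weight=0.3,         # blend a CTC term into the training objective
        ctc_loss_reduction="mean",
        use_fddt=True,          # enable diarization-dependent transformations
        is_mt=True,
        mt_num_speakers=2,      # hypothetical two-speaker setup
    )
    print(cfg.model_type, cfg.ctc_weight, cfg.mt_num_speakers)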