from typing import List

from transformers import PretrainedConfig


class transformerConfig(PretrainedConfig):
    """Configuration class for a custom Transformer model."""

    model_type = "custom_transformer"

    def __init__(
        self,
        src_vocab_len: int = 184,
        tgt_vocab: int = 201,
        num_hiddens: int = 32,
        num_layers: int = 2,
        dropout: float = 0.1,
        batch_size: int = 64,
        num_steps: int = 10,
        lr: float = 0.005,
        num_epochs: int = 200,
        # device=d2l.try_gpu(),
        ffn_num_input: int = 32,
        ffn_num_hiddens: int = 64,
        num_heads: int = 4,
        key_size: int = 32,
        query_size: int = 32,
        value_size: int = 32,
        norm_shape: List[int] = [32],
        # Unused ResNet-style parameters, left commented out for reference:
        # block_type="bottleneck",
        # layers: List[int] = [3, 4, 6, 3],
        # num_classes: int = 1000,
        # input_channels: int = 3,
        # cardinality: int = 1,
        # base_width: int = 64,
        # stem_width: int = 64,
        # stem_type: str = "",
        # avg_down: bool = False,
        **kwargs,
    ):
        # if block_type not in ["basic", "bottleneck"]:
        #     raise ValueError(f"`block_type` must be 'basic' or 'bottleneck', got {block_type}.")
        # if stem_type not in ["", "deep", "deep-tiered"]:
        #     raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
        self.src_vocab_len = src_vocab_len
        self.tgt_vocab = tgt_vocab
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.lr = lr
        self.num_epochs = num_epochs
        self.ffn_num_input = ffn_num_input
        self.ffn_num_hiddens = ffn_num_hiddens
        self.num_heads = num_heads
        self.key_size = key_size
        self.query_size = query_size
        self.value_size = value_size
        self.norm_shape = norm_shape

        super().__init__(**kwargs)
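

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file).
# It shows how this config could be instantiated, serialized, and reloaded
# with the standard PretrainedConfig helpers. The directory name
# "./custom_transformer" and the overridden values are assumptions made
# purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a config with a few non-default hyper-parameters.
    config = transformerConfig(num_hiddens=64, num_heads=8, norm_shape=[64])

    # Serialize to config.json and load it back.
    config.save_pretrained("./custom_transformer")
    reloaded = transformerConfig.from_pretrained("./custom_transformer")
    print(reloaded.num_hiddens, reloaded.num_heads, reloaded.norm_shape)

    # Optionally, the config class can be registered so AutoConfig resolves it:
    # from transformers import AutoConfig
    # AutoConfig.register("custom_transformer", transformerConfig)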