File size: 2,958 Bytes
d6def08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import ast
from typing import List, Tuple

from config.config import Config


class TrainConfig(Config):
    """Training configuration stored as class-level attributes.

    The attributes below hold the default values. Call :meth:`setup` to
    override any subset of them; the overrides are written onto the class
    itself, so they are visible to every user of ``TrainConfig``.
    """

    # Region-proposal settings (names suggest proposal counts kept before /
    # after NMS -- confirm against the consumer of these values).
    RPN_PRE_NMS_TOP_N: int = 12000
    RPN_POST_NMS_TOP_N: int = 2000

    # Beta parameters for the smooth-L1 losses.
    ANCHOR_SMOOTH_L1_LOSS_BETA: float = 1.0
    PROPOSAL_SMOOTH_L1_LOSS_BETA: float = 1.0

    # Optimizer and learning-rate schedule settings.
    BATCH_SIZE: int = 1
    LEARNING_RATE: float = 0.001
    MOMENTUM: float = 0.9
    WEIGHT_DECAY: float = 0.0005
    STEP_LR_SIZES: List[int] = [50000, 70000]
    STEP_LR_GAMMA: float = 0.1
    WARM_UP_FACTOR: float = 0.3333
    WARM_UP_NUM_ITERS: int = 500

    # Step counts controlling display, snapshot and termination.
    NUM_STEPS_TO_DISPLAY: int = 20
    NUM_STEPS_TO_SNAPSHOT: int = 10000
    NUM_STEPS_TO_FINISH: int = 90000

    @classmethod
    def setup(cls, image_min_side: float = None, image_max_side: float = None,
              anchor_ratios: List[Tuple[int, int]] = None, anchor_sizes: List[int] = None, pooler_mode: str = None,
              rpn_pre_nms_top_n: int = None, rpn_post_nms_top_n: int = None,
              anchor_smooth_l1_loss_beta: float = None, proposal_smooth_l1_loss_beta: float = None,
              batch_size: int = None, learning_rate: float = None, momentum: float = None, weight_decay: float = None,
              step_lr_sizes: List[int] = None, step_lr_gamma: float = None,
              warm_up_factor: float = None, warm_up_num_iters: int = None,
              num_steps_to_display: int = None, num_steps_to_snapshot: int = None, num_steps_to_finish: int = None):
        """Override class-level configuration values.

        Every argument defaults to ``None``, which means "keep the current
        value". The first five arguments are forwarded unchanged to the
        parent ``setup``; each remaining argument, when not ``None``, is
        stored on the class under its upper-cased name.

        Args:
            step_lr_sizes: Either a list/tuple of ints, which is copied into
                a new list, or a string holding a Python literal such as
                ``"[50000, 70000]"``, which is parsed with
                ``ast.literal_eval``.

        Raises:
            ValueError: If ``step_lr_sizes`` is a string that is not a valid
                Python literal (raised by ``ast.literal_eval``).
        """
        super().setup(image_min_side, image_max_side, anchor_ratios, anchor_sizes, pooler_mode)

        # Simple pass-through overrides: attribute name -> supplied value.
        overrides = {
            'RPN_PRE_NMS_TOP_N': rpn_pre_nms_top_n,
            'RPN_POST_NMS_TOP_N': rpn_post_nms_top_n,
            'ANCHOR_SMOOTH_L1_LOSS_BETA': anchor_smooth_l1_loss_beta,
            'PROPOSAL_SMOOTH_L1_LOSS_BETA': proposal_smooth_l1_loss_beta,
            'BATCH_SIZE': batch_size,
            'LEARNING_RATE': learning_rate,
            'MOMENTUM': momentum,
            'WEIGHT_DECAY': weight_decay,
            'STEP_LR_GAMMA': step_lr_gamma,
            'WARM_UP_FACTOR': warm_up_factor,
            'WARM_UP_NUM_ITERS': warm_up_num_iters,
            'NUM_STEPS_TO_DISPLAY': num_steps_to_display,
            'NUM_STEPS_TO_SNAPSHOT': num_steps_to_snapshot,
            'NUM_STEPS_TO_FINISH': num_steps_to_finish,
        }
        for attr_name, value in overrides.items():
            if value is not None:
                setattr(cls, attr_name, value)

        if step_lr_sizes is not None:
            if isinstance(step_lr_sizes, (list, tuple)):
                # An already-parsed sequence: ast.literal_eval would reject
                # it with ValueError, so copy it instead (the copy also
                # avoids aliasing the caller's list).
                cls.STEP_LR_SIZES = list(step_lr_sizes)
            else:
                # String form, e.g. as received from a command-line option.
                cls.STEP_LR_SIZES = ast.literal_eval(step_lr_sizes)