File size: 2,001 Bytes
b4942cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    """Model-construction options.

    The field names suggest an LLM backbone combined with a visual tokenizer;
    how each option is consumed is decided by code outside this file.

    Every field has a default, so the class can be built with no arguments.
    Fields that default to ``None`` are annotated ``Optional[...]`` so the
    annotation matches the default value.
    """

    llm_name_or_path: Optional[str] = field(default=None)
    visual_tokenizer_type: Optional[str] = field(default=None)
    visual_vocab_size: int = field(default=8192)
    visual_drop_cls_token: bool = field(default=False)
    visual_tokenize_function: str = field(default='softmax')
    visual_tau: float = field(default=1.0)
    # Stored as a string (or None); any parsing happens in the consumer.
    visual_depths: Optional[str] = field(default=None)
    visual_hidden_stride: int = field(default=1)
    multimodal_max_length: int = field(default=2048)
    conversation_formatter_class: Optional[str] = field(default=None)
    pad_token_id: Optional[int] = field(default=None)
    llm_attn_implementation: Optional[str] = field(default=None)
    disable_tie_weight: bool = field(default=False)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project-specific options.

    Adds dataset selection, staging, visual-tokenizer temperature bounds, and
    related settings. Also redeclares a few inherited fields (``optim``,
    ``save_safetensors``) with the defaults shown below.
    """

    dataset_names: Optional[str] = field(default=None)  # a|b|c
    dataset_info: Optional[str] = field(default='dataset_info_v1_6')
    ovis_pretrained_path: Optional[str] = field(default=None)
    visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
    caption_template: Optional[str] = field(default=None)
    stage: Optional[int] = field(default=None)
    train_modules: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    visual_max_tau: float = field(default=5.0)
    visual_min_tau: float = field(default=0.05)
    save_safetensors: bool = field(default=True)
    monitor_step: int = field(default=100)
    vte_re_init: bool = field(default=False)
    text_max_length: int = field(default=1024)
    max_partitions: str = field(default="9|1|1")

    def __post_init__(self):
        """Adjust dependent options, then run the parent's post-init.

        - If ``gradient_checkpointing`` is enabled, ``gradient_checkpointing_kwargs``
          is set to ``{"use_reentrant": False}``.
        - If ``stage`` is given and is below 3, ``save_safetensors`` is turned
          off. The reason for this rule is not stated in this file.
        - Finally delegates to ``transformers.TrainingArguments.__post_init__``.
        """
        if self.gradient_checkpointing:
            # NOTE: this replaces (does not merge with) any user-supplied
            # gradient_checkpointing_kwargs.
            self.gradient_checkpointing_kwargs = {"use_reentrant": False}
        # ``stage`` defaults to None; comparing ``None < 3`` raises TypeError,
        # so the stage rule only applies when a stage was actually provided.
        if self.stage is not None and self.stage < 3:
            self.save_safetensors = False
        super().__post_init__()