model:
  type: "custom"
  name: "smollm2_transformer"
  tokenizer_name: "gpt2"
  vocab_size: 50257
  hidden_size: 256
  num_attention_heads: 4
  num_key_value_heads: 2
  num_hidden_layers: 6
  intermediate_size: 512
  hidden_act: "gelu"
  max_position_embeddings: 256
  initializer_range: 0.02
  rms_norm_eps: 1.0e-5
  use_cache: true
  pad_token_id: null
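  # Notes derived from the values above: head_dim = hidden_size / num_attention_heads = 256 / 4 = 64,
  # and num_key_value_heads: 2 means each KV head is shared by 2 query heads (grouped-query attention).
  # vocab_size: 50257 matches the "gpt2" tokenizer, and max_position_embeddings equals
  # training.sequence_length (256), so sequences are never longer than the position table.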

optimizer:
  type: "adamW"
  weight_decay: 0.01
  adam_beta1: 0.9
  adam_beta2: 0.95
  adam_eps: 1.0e-8
  torch_adam_is_fused: true
  clip_grad: 1.0
  accumulate_grad_in_fp32: true
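  # Betas (0.9, 0.95) follow common LM-pretraining practice, and clip_grad here matches
  # hardware.gradient_clip below. accumulate_grad_in_fp32 presumably keeps the gradient
  # accumulation buffer in fp32 even under 16-mixed precision (an assumption about how the
  # training script consumes this flag).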

scheduler:
  type: "one_cycle"
  learning_rate: 0.001
  warmup_steps: 50
  max_lr: 0.001
  pct_start: 0.02
  anneal_strategy: "cos"
  cycle_momentum: false
  div_factor: 25.0
  final_div_factor: 1000.0
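  # Assuming PyTorch-style OneCycleLR semantics for "one_cycle": the initial lr is
  # max_lr / div_factor = 0.001 / 25 = 4.0e-5, and the final lr is
  # initial lr / final_div_factor = 4.0e-8. pct_start: 0.02 of training.max_steps (5050)
  # gives roughly 101 ramp-up steps; how that interacts with warmup_steps: 50 depends on
  # the trainer implementation.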

training:
  output_dir: "./results"
  batch_size: 4
  micro_batch_size: 2
  gradient_accumulation_steps: 2
  sequence_length: 256
  learning_rate: 0.001
  max_steps: 5050  # Total steps (5000 + 50)
  first_phase_steps: 5000  # Initial training phase
  second_phase_steps: 50   # Fine-tuning phase
  sample_frequency: 100    # Sample every 100 steps in first phase
  second_phase_sample_frequency: 5  # Sample more frequently in second phase
  logging_dir: "./logs"
  logging_steps: 1
  save_steps: 100
  checkpoint_dir: "checkpoints"
  sample_prompt: "Explain what machine learning is:"
  max_generate_length: 50
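  # Effective batch: micro_batch_size * gradient_accumulation_steps = 2 * 2 = 4 = batch_size,
  # i.e. 4 * 256 = 1,024 tokens per optimizer step. The run splits into 5,000 + 50 = 5,050
  # steps, sampling from sample_prompt every 100 steps in the first phase and every 5 steps
  # in the second.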

hardware:
  precision: "16-mixed"
  accelerator: "gpu"
  devices: 1
  strategy: "auto"
  gradient_clip: 1.0
  cuda_memory_fraction: 0.9
  allow_tf32: true
  benchmark: true
  deterministic: false
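  # "16-mixed" is the Lightning-style name for fp16 mixed precision (assuming a Lightning
  # Trainer consumes this block). allow_tf32 enables TF32 matmuls on Ampere+ GPUs, and
  # benchmark: true with deterministic: false trades exact reproducibility for cuDNN speed.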

data:
  datasets:
    - name: "wikitext"
      path: "wikitext"
      subset: "wikitext-103-raw-v1"
      split_ratio: 0.01  # Use only 1% of the dataset
      weight: 1.0
  loading:
    num_workers: 2
    batch_size: 16
    pin_memory: true
    prefetch_factor: 2
    persistent_workers: true
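  # The single entry presumably loads the Hugging Face "wikitext" dataset with config
  # "wikitext-103-raw-v1", keeping 1% of it; weight only matters when several datasets are
  # mixed. The loading keys mirror standard torch DataLoader arguments (num_workers,
  # pin_memory, prefetch_factor, persistent_workers).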