Saving train state of step 5
- distil-whisper/events.out.tfevents.1715173957.server02.1931183.0 +3 -0
- distil-whisper/events.out.tfevents.1715174400.server02.1934277.0 +3 -0
- distil-whisper/events.out.tfevents.1715174461.server02.1934867.0 +3 -0
- distil-whisper/events.out.tfevents.1715174772.server02.1937015.0 +3 -0
- distil-whisper/events.out.tfevents.1715174837.server02.1937715.0 +3 -0
- distil-whisper/events.out.tfevents.1715174907.server02.1938409.0 +3 -0
- distil-whisper/events.out.tfevents.1715183755.server02.1990428.0 +3 -0
- run_distillation.py +28 -6
distil-whisper/events.out.tfevents.1715173957.server02.1931183.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0baaf5c51a7181547abed27026c4d96f6e67c79735b12da3ff8ab92b5ad8d34
+size 88
distil-whisper/events.out.tfevents.1715174400.server02.1934277.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d55b57f53ba824e31a3e70e10be7b81fffbc7136cf863b03721bb34f860c36e1
+size 88
distil-whisper/events.out.tfevents.1715174461.server02.1934867.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a7cbaa5531dd3b4ce37f609a4f68d565dbcdc01c5b996dc862d7c531efbe032
+size 88
distil-whisper/events.out.tfevents.1715174772.server02.1937015.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17d80187904c3f499e62daa25bd6a595704610866a8ca14f7464f436cae37096
+size 88
distil-whisper/events.out.tfevents.1715174837.server02.1937715.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba268dc4a701175857d00571487bd02a56b657427e5c6c19bbc4d5d828ee2479
+size 88
distil-whisper/events.out.tfevents.1715174907.server02.1938409.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b99ee4c468275a20eec36de30b0e0bd3326c2605623203ae6724843aea1dfc78
+size 88
distil-whisper/events.out.tfevents.1715183755.server02.1990428.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89067edc26fa465bd32a850c8f10ead9195d7e28ea74ffab40e7e4c485c2e403
+size 88
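The seven files above are TensorBoard event logs. On the Hub they are stored as Git LFS pointers (hence the 88-byte size), so the real logs have to be fetched with git lfs pull before they can be read. A minimal sketch of inspecting one of them locally, assuming the repository is cloned and the tensorboard package is installed; the scalar tag names depend on what the script actually logged and are not read from these files:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# one of the pointer files above, after `git lfs pull` has replaced it with the real log
event_file = "distil-whisper/events.out.tfevents.1715173957.server02.1931183.0"

acc = EventAccumulator(event_file)
acc.Reload()                          # parse the event file
tags = acc.Tags()["scalars"]          # e.g. training loss / learning-rate tags, if any were written
print(tags)
for event in (acc.Scalars(tags[0]) if tags else []):
    print(event.step, event.value)    # (global step, logged scalar value)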
run_distillation.py
CHANGED
@@ -1312,7 +1312,8 @@ def main():
         num_epochs = int(training_args.num_train_epochs)
         steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
         total_train_steps = steps_per_epoch * num_epochs
-
+
+    elif training_args.max_steps > 0:  # since we use data streaming, this condition is satisfied
         logger.info("max_steps is given, it will override any value given in num_train_epochs")
         total_train_steps = int(training_args.max_steps)
         if not data_args.streaming:
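A note on this hunk: with --streaming the training set is an iterable dataset that has no len(), so the epoch-based arithmetic in the context lines above cannot run and --max_steps has to define the training horizon, which is what the restored elif branch handles. A small, self-contained sketch of that decision logic (a hypothetical helper, not a function from run_distillation.py):

from typing import Optional

def resolve_total_train_steps(max_steps: int, streaming: bool, dataset_len: Optional[int],
                              batch_size: int, grad_accum: int, num_epochs: int) -> int:
    # mirrors the branch above: an explicit --max_steps always wins
    if max_steps > 0:
        return max_steps
    # otherwise a finite dataset length is needed to derive steps per epoch
    if streaming or dataset_len is None:
        raise ValueError("a streaming (iterable) dataset has no len(); pass --max_steps instead")
    steps_per_epoch = dataset_len // (batch_size * grad_accum)
    return steps_per_epoch * num_epochs

# with the launch command at the bottom of this diff (--max_steps 5, --streaming, batch size 4):
print(resolve_total_train_steps(max_steps=5, streaming=True, dataset_len=None,
                                batch_size=4, grad_accum=1, num_epochs=1))  # -> 5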
@@ -1427,14 +1428,14 @@
         student_model.train()
         teacher_model.eval()
 
-        student_outputs = student_model(**batch)
+        student_outputs = student_model(**batch)  # __call__ dispatches to forward(); note: student_model and teacher_model are both WhisperForConditionalGeneration objects
         with torch.no_grad():
-            if share_hidden_states:
+            if share_hidden_states:
                 # if the student and teacher share the same frozen encoder then we don't have to recompute the
                 # encoder hidden-states for the teacher model, we can just re-use from the student
                 encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
                 teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
-            else:
+            else:
                 # do the full forward pass for the teacher model (encoder + decoder)
                 teacher_outputs = teacher_model(**batch)
 
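The student and teacher outputs gathered in this hunk feed the distillation loss; train_step (called further down with temperature=training_args.temperature) is expected to combine the student's cross-entropy loss with a temperature-scaled KL term between the two sets of logits. A hedged, self-contained sketch of such a term — not code lifted from this repository's train_step; the 0.8/1.0 weighting and the -100 padding convention are assumptions:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, ce_loss, temperature=2.0):
    # soften both distributions with the same temperature
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # token-level KL, averaged only over non-padded label positions (padding assumed to be -100)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(-1)
    mask = (labels != -100).float()
    kl = (kl * mask).sum() / mask.sum()
    kl = kl * temperature**2                # keep the gradient scale independent of the temperature
    return 0.8 * ce_loss + 1.0 * kl         # assumed weighting of the CE and KL terms

# toy shapes: batch of 2, 5 decoder positions, vocabulary of 10
student = torch.randn(2, 5, 10, requires_grad=True)
teacher = torch.randn(2, 5, 10)
labels = torch.randint(0, 10, (2, 5))
labels[:, -1] = -100                        # pretend the last position is padding
print(distillation_loss(student, teacher, labels, ce_loss=torch.tensor(1.0)))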
@@ -1546,8 +1547,24 @@
     print(f" vectorized_datasets['eval'] : {vectorized_datasets['eval']}")
     print(f" vectorized_datasets['train'] : {vectorized_datasets['train']}")
 
-
+    # see example of validation dataloader
+    # validation_dataloader = DataLoader(
+    #     vectorized_datasets[eval_split],
+    #     collate_fn=data_collator,
+    #     batch_size=per_device_eval_batch_size,
+    #     drop_last=False,
+    #     num_workers=dataloader_num_workers,
+    #     prefetch_factor=prefetch_factor,
+    #     pin_memory=training_args.dataloader_pin_memory,
+    # )
+
+    # for batch in validation_dataloader:
+    #     print(batch['input_features'].shape)
+
 
+    print(f" student_model : {type(student_model)}")
+
+
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
         train_dataloader = DataLoader(
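The commented-out block above is a one-off sanity check of the eval dataloader. A self-contained stand-in that shows what the shape print would report, with dummy data in place of vectorized_datasets and the script's collator; all names here are illustrative, and the 128-mel x 3000-frame feature shape assumes a large-v3-family feature extractor:

import torch
from torch.utils.data import DataLoader, Dataset

class DummyFeatures(Dataset):
    # stands in for vectorized_datasets["eval"]: log-mel features plus token labels
    def __len__(self):
        return 6
    def __getitem__(self, idx):
        return {"input_features": torch.zeros(128, 3000), "labels": torch.tensor([1, 2, 3])}

def collate(examples):
    # stands in for the script's data_collator: stack features and labels into batches
    return {"input_features": torch.stack([e["input_features"] for e in examples]),
            "labels": torch.stack([e["labels"] for e in examples])}

loader = DataLoader(DummyFeatures(), batch_size=2, collate_fn=collate, drop_last=False)
for batch in loader:
    print(batch["input_features"].shape)   # torch.Size([2, 128, 3000])
    break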
@@ -1570,10 +1587,13 @@
 
         for batch in train_dataloader:
             with accelerator.accumulate(student_model):
+                # parameters are updated every batch
                 loss, train_metric = train_step(batch, temperature=training_args.temperature)
+                # backward pass on the loss
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
+                # parameter update after the forward and backward passes
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
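The comments added in this hunk describe the Accelerate gradient-accumulation pattern: inside accelerator.accumulate(...) the backward pass only synchronises gradients every gradient_accumulation_steps micro-batches, accelerator.sync_gradients flags when that happened, and an optimizer returned by accelerator.prepare skips step() while accumulation is still in progress, which is why clipping and the step counter further down are gated on sync_gradients. A small self-contained sketch of that behaviour with a toy model (nothing here comes from the script):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

for i in range(8):                                   # 8 micro-batches -> 2 effective optimizer steps
    batch = torch.randn(2, 8, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        if accelerator.sync_gradients:               # True only on every 4th micro-batch
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()                             # skipped internally while still accumulating
        optimizer.zero_grad()
    if accelerator.sync_gradients:
        print(f"effective optimizer step after micro-batch {i}")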
@@ -1582,7 +1602,9 @@
             if accelerator.sync_gradients:
                 steps_trained_progress_bar.update(1)
                 cur_step += 1
+
 
+                # logging timing
                 if cur_step % training_args.logging_steps == 0:
                     steps_trained_progress_bar.write(
                         f"Step... ({cur_step} / {total_train_steps} | Loss:"
@@ -1733,5 +1755,5 @@
 if __name__ == "__main__":
     main()
 '''
-accelerate launch --mixed_precision=bf16 run_distillation.py --model_name_or_path "./distil-large-v3-init" --teacher_model_name_or_path "openai/whisper-large-v3" --train_dataset_name "mozilla-foundation/common_voice_15_0" --train_dataset_config_name "de" --train_split_name "train" --text_column_name "sentence" --eval_dataset_name "mozilla-foundation/common_voice_15_0" --eval_dataset_config_name "de" --eval_split_name "validation" --eval_text_column_name "sentence" --eval_steps 5 --save_steps 50 --warmup_steps 500 --learning_rate 1e-4 --lr_scheduler_type "linear" --logging_steps 25 --save_total_limit 1 --max_steps 5 --per_device_train_batch_size 4 --per_device_eval_batch_size 2 --dataloader_num_workers 2 --preprocessing_num_workers 2 --ddp_timeout 7200 --dtype "bfloat16" --output_dir "./" --use_pseudo_labels "false" --condition_on_prev_probability "0.0" --do_train --do_eval --gradient_checkpointing --overwrite_output_dir --predict_with_generate --freeze_encoder --streaming --push_to_hub --language de
+accelerate launch --mixed_precision=bf16 run_distillation.py --model_name_or_path "./distil-large-v3-init" --teacher_model_name_or_path "openai/whisper-large-v3" --train_dataset_name "mozilla-foundation/common_voice_15_0" --train_dataset_config_name "de" --train_split_name "train" --text_column_name "sentence" --eval_dataset_name "mozilla-foundation/common_voice_15_0" --eval_dataset_config_name "de" --eval_split_name "validation" --eval_text_column_name "sentence" --eval_steps 5 --save_steps 50 --warmup_steps 500 --learning_rate 1e-4 --lr_scheduler_type "linear" --logging_steps 25 --save_total_limit 1 --max_steps 5 --per_device_train_batch_size 4 --per_device_eval_batch_size 2 --dataloader_num_workers 2 --preprocessing_num_workers 2 --ddp_timeout 7200 --dtype "bfloat16" --output_dir "./" --use_pseudo_labels "false" --condition_on_prev_probability "0.0" --do_train --do_eval --gradient_checkpointing --overwrite_output_dir --predict_with_generate --freeze_encoder --streaming --push_to_hub --language de --max_eval_samples 5
 '''
|