import os
import sys
import logging
import pathlib
import typing
import warnings

SLURM_ENV = {k: v for k, v in os.environ.items() if 'SLURM' in k}
if SLURM_ENV:
    print(f"SLURM_ENV: {SLURM_ENV}")
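
# Make the project root importable so the mllm imports below resolve when this script is run directly.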
project_path = pathlib.Path(__file__).parent.parent.parent
sys.path.append(str(project_path))

import torch
import torch.cuda

from mllm.config import prepare_args
from mllm.models import load_pretrained
from mllm.utils import print_trainable_params
from mllm.engine import prepare_trainer_collator
from mllm.dataset import prepare_data, prepare_target_processor
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def main():
    cfg, training_args = prepare_args()
    model, preprocessor = load_pretrained(cfg.model_args, training_args)
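    # prepare_target_processor may update both the model and the preprocessor, so rebind them here.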
    model, preprocessor = prepare_target_processor(model, preprocessor, cfg.model_args, training_args)
    print_trainable_params(model)
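
    # Build the trainer class, data collators, datasets, and metric function from the config.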
    collator_kwargs = cfg.data_args.collator_kwargs
    trainer_cls, data_collator_dict = prepare_trainer_collator(cfg.model_args, preprocessor, collator_kwargs)
    dataset, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor)
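
    # Assemble the trainer; the collator(s) are injected through data_collator_dict.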
    trainer = trainer_cls(
        model=model,
        args=training_args,
        tokenizer=preprocessor['text'],
        train_dataset=dataset['train'] if training_args.do_train else None,
        eval_dataset=dataset['validation'] if training_args.do_eval else None,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        **data_collator_dict,
    )
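
    # Training: resume from the latest checkpoint when the output dir already holds one
    # and overwrite_output_dir is not set.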
    if training_args.do_train:
        try:
            if (not training_args.overwrite_output_dir) and list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
                train_result = trainer.train(resume_from_checkpoint=True)
            else:
                train_result = trainer.train()
            trainer.log_metrics("train", train_result.metrics)
            trainer.save_metrics("train", train_result.metrics)
            trainer.save_model()
        except RuntimeError as e:
            print(f"got RuntimeError: {e.args}")
            try:
                print(f"#### device {training_args.local_rank} summary ####\n{torch.cuda.memory_summary(training_args.local_rank)}")
            except Exception as inner_e:
                print(f"got an exception while printing the cuda memory summary: {inner_e.args}")
            raise e
        finally:
            trainer.save_state()
            trainer.plot_loss()
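
    # Persist the resolved config next to the run outputs; a failure here only raises a warning.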
    try:
        output_dir = training_args.output_dir
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
        cfg.dump(os.path.join(output_dir, "cfg.py"))
    except Exception as e:
        warnings.warn(f'tried to save cfg to output_dir but got an exception: {e.args}')
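
    # Keyword arguments forwarded to trainer.predict below.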
    gen_kwargs = dict(cfg.data_args.gen_kwargs)
    gen_kwargs.setdefault('use_cache', True)
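    # Optionally pin the special token ids used for generation, controlled per model via cfg.model_args.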
    if hasattr(cfg.model_args, 'gen_kwargs_set_pad_token_id') and cfg.model_args.gen_kwargs_set_pad_token_id:
        gen_kwargs['pad_token_id'] = preprocessor['text'].pad_token_id
    if hasattr(cfg.model_args, 'gen_kwargs_set_bos_token_id') and cfg.model_args.gen_kwargs_set_bos_token_id:
        gen_kwargs['bos_token_id'] = preprocessor['text'].bos_token_id
    if hasattr(cfg.model_args, 'gen_kwargs_set_eos_token_id') and cfg.model_args.gen_kwargs_set_eos_token_id:
        gen_kwargs['eos_token_id'] = preprocessor['text'].eos_token_id
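
    # Evaluation: runs through trainer.predict so the generation kwargs above apply.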
    if training_args.do_eval:
        if hasattr(trainer, '_test_collator') and hasattr(trainer, '_eval_collator') \
                and trainer._test_collator != trainer._eval_collator:
            warnings.warn('Different collators are configured for eval and test, but both do_eval and '
                          'do_predict go through trainer.predict, so only the test collator is used.')
        eval_results = trainer.predict(dataset['validation'], metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", eval_results.metrics)
        trainer.save_metrics("eval", eval_results.metrics)
        trainer.save_prediction(eval_results, file_key_prefix='eval')
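
    # Prediction on the test split.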
    if training_args.do_predict:
        predict_results = trainer.predict(dataset['test'], metric_key_prefix="test", **gen_kwargs)
        trainer.log_metrics("test", predict_results.metrics)
        trainer.save_metrics("test", predict_results.metrics)
        trainer.save_prediction(predict_results, file_key_prefix='test')
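
    # Prediction over several named test sets, each with its own compute_metric;
    # the trainer's original compute_metrics is restored afterwards.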
    if training_args.do_multi_predict:
        old_compute_metrics = trainer.compute_metrics
        multitest = dataset['multitest']
        multitest = typing.cast(dict, multitest)
        for _idx, (k, item) in enumerate(multitest.items()):
            print(f'processing multitest set {_idx}/{len(multitest)}: {k}')
            _ds = item['dataset']
            _compute_metrics = item['compute_metric']
            _prefix = f"multitest_{k}"

            trainer.compute_metrics = _compute_metrics
            _pred_results = trainer.predict(_ds, metric_key_prefix=_prefix, **gen_kwargs)
            trainer.log_metrics(_prefix, _pred_results.metrics)
            trainer.save_metrics(_prefix, _pred_results.metrics)
            trainer.save_prediction(_pred_results, file_key_prefix=_prefix)
        trainer.compute_metrics = old_compute_metrics
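

# Entry point for spawn-based launchers that pass a process index (e.g. xla_spawn); the index itself is unused.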
def _mp_fn(index):
    main()


if __name__ == "__main__":
    main()