# File: RxnIM/mllm/pipeline/finetune.py
# Repository page header (not code): uploader avatar "CYF200127's picture",
# commit message "Upload 235 files", revision 3e1d9f3 (verified).
import os
import sys
import logging
import pathlib
import typing
import warnings
# Snapshot of every environment variable whose name contains "SLURM"; it is
# echoed once at import time when at least one such variable is present.
SLURM_ENV = {
    name: value
    for name, value in os.environ.items()
    if 'SLURM' in name
}
if SLURM_ENV:
    print(f"SLURM_ENV: {SLURM_ENV}")

# The directory three levels above this file is appended to the import search
# path before the `mllm` imports that follow are resolved.
project_path = pathlib.Path(__file__).parent.parent.parent
sys.path.append(os.fspath(project_path))
import torch
import torch.cuda
from mllm.config import prepare_args
from mllm.models import load_pretrained
from mllm.utils import print_trainable_params
from mllm.engine import prepare_trainer_collator
from mllm.dataset import prepare_data, prepare_target_processor
# Root logging setup: one stream handler writing to stdout, with a
# "timestamp - level - logger name - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger with its own level set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _run_training(trainer, training_args):
    """Run the training phase and persist its metrics, model and state.

    Training resumes through ``trainer.train(resume_from_checkpoint=True)``
    when ``training_args.overwrite_output_dir`` is falsy and the output
    directory already holds an entry matching ``checkpoint-*``; otherwise a
    plain ``trainer.train()`` is started.

    A ``RuntimeError`` raised during training is printed together with a
    best-effort CUDA memory summary and then re-raised unchanged.
    ``trainer.save_state()`` and ``trainer.plot_loss()`` run whether or not
    training succeeded.
    """
    try:
        # `any` stops at the first match; the glob is only evaluated when
        # overwriting is disabled, exactly as the short-circuit requires.
        should_resume = (not training_args.overwrite_output_dir) and any(
            pathlib.Path(training_args.output_dir).glob("checkpoint-*")
        )
        if should_resume:
            train_result = trainer.train(resume_from_checkpoint=True)
        else:
            train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)  # noqa
        trainer.save_metrics("train", train_result.metrics)  # noqa
        trainer.save_model()
    except RuntimeError as e:
        print(f"got RuntimeError: {e.args}")
        try:
            local_rank = training_args.local_rank
            # NOTE(review): local_rank is presumably -1 outside distributed
            # launches (the Transformers default). A negative index is not a
            # valid CUDA device, so fall back to None, which makes
            # torch.cuda.memory_summary report on the current device.
            if isinstance(local_rank, int) and local_rank >= 0:
                device = local_rank
            else:
                device = None
            print(f"#### device {training_args.local_rank} summary ####\n{torch.cuda.memory_summary(device)}")
        except Exception as inner_e:
            # The summary is purely diagnostic; never let it hide the real error.
            print(f"get Exception when show cuda summary: {inner_e.args}")
        raise
    finally:
        trainer.save_state()  # noqa
        trainer.plot_loss()


def _save_cfg(cfg, training_args):
    """Best-effort dump of ``cfg`` to ``<output_dir>/cfg.py``.

    Any exception is downgraded to a warning so that a failed dump never
    aborts evaluation or prediction.
    """
    try:
        output_dir = training_args.output_dir
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
        cfg.dump(os.path.join(output_dir, "cfg.py"))
    except Exception as e:
        warnings.warn(f'try to save cfg to output_dir, but get exception {e.args}')


def _build_gen_kwargs(cfg, preprocessor):
    """Return the keyword arguments forwarded to ``trainer.predict``.

    Starts from a copy of ``cfg.data_args.gen_kwargs`` and defaults
    ``use_cache`` to True. For each of pad/bos/eos, when
    ``cfg.model_args.gen_kwargs_set_<name>_token_id`` exists and is truthy,
    the id is taken from the text tokenizer (``preprocessor['text']``).
    """
    gen_kwargs = dict(cfg.data_args.gen_kwargs)
    gen_kwargs.setdefault('use_cache', True)
    # Important for batched generation: some model configs ship wrong special
    # token ids (e.g. shikra's generation config sets pad_token_id to -1).
    for token_key in ('pad_token_id', 'bos_token_id', 'eos_token_id'):
        if getattr(cfg.model_args, f'gen_kwargs_set_{token_key}', False):
            gen_kwargs[token_key] = getattr(preprocessor['text'], token_key)
    return gen_kwargs


def _predict_and_save(trainer, dataset, prefix, gen_kwargs):
    """Predict on ``dataset`` and log/save metrics and predictions under ``prefix``."""
    results = trainer.predict(dataset, metric_key_prefix=prefix, **gen_kwargs)
    trainer.log_metrics(prefix, results.metrics)  # noqa
    trainer.save_metrics(prefix, results.metrics)  # noqa
    trainer.save_prediction(results, file_key_prefix=prefix)


def main():
    """Fine-tune and/or evaluate a model according to ``prepare_args()``.

    The function loads the model and preprocessor, builds the trainer and
    datasets, and then runs the phases enabled on ``training_args``:
    ``do_train``, ``do_eval``, ``do_predict`` and ``do_multi_predict``.
    The configuration is dumped to the output directory before any
    evaluation or prediction phase runs.

    Raises:
        RuntimeError: re-raised from training after diagnostics are printed.
    """
    cfg, training_args = prepare_args()
    model, preprocessor = load_pretrained(cfg.model_args, training_args)
    # Inject the target processor into the preprocessor. This step may also
    # alter the model (e.g. add special tokens, resize embeddings), so the
    # returned model replaces the loaded one.
    model, preprocessor = prepare_target_processor(model, preprocessor, cfg.model_args, training_args)
    print_trainable_params(model)

    # Prepare data_collator
    collator_kwargs = cfg.data_args.collator_kwargs
    trainer_cls, data_collator_dict = prepare_trainer_collator(cfg.model_args, preprocessor, collator_kwargs)
    dataset, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor)

    # Initialize Trainer; splits and metrics are only wired in when the
    # matching training_args flag is set.
    trainer = trainer_cls(
        model=model,
        args=training_args,
        tokenizer=preprocessor['text'],
        train_dataset=dataset['train'] if training_args.do_train else None,
        eval_dataset=dataset['validation'] if training_args.do_eval else None,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        **data_collator_dict,
    )

    # Training
    if training_args.do_train:
        _run_training(trainer, training_args)

    # save cfg to output_dir
    _save_cfg(cfg, training_args)

    # Keyword arguments for `model.generate`
    gen_kwargs = _build_gen_kwargs(cfg, preprocessor)

    # Evaluation
    if training_args.do_eval:
        collators_differ = (
            hasattr(trainer, '_test_collator')
            and hasattr(trainer, '_eval_collator')
            and trainer._test_collator != trainer._eval_collator  # noqa
        )
        if collators_differ:
            warnings.warn('[WARNING!!!] use different collator for eval and test. but do_eval and '
                          'do_predict both use trainer.predict (i.e. only test_collator is used.)')
        _predict_and_save(trainer, dataset['validation'], "eval", gen_kwargs)

    # Predict
    if training_args.do_predict:
        _predict_and_save(trainer, dataset['test'], "test", gen_kwargs)

    # Multi Predict
    if training_args.do_multi_predict:
        old_compute_metrics = trainer.compute_metrics
        multitest = typing.cast(dict, dataset['multitest'])
        try:
            for _idx, (k, item) in enumerate(multitest.items()):
                print(f'processing multitest set {_idx}/{len(multitest)}: {k}')
                _ds = item['dataset']
                # Each multitest set carries its own metric function.
                trainer.compute_metrics = item['compute_metric']
                _predict_and_save(trainer, _ds, f"multitest_{k}", gen_kwargs)
        finally:
            # Restore the trainer's metric function even if a set fails.
            trainer.compute_metrics = old_compute_metrics
# noinspection PyUnusedLocal
def _mp_fn(index):
    """Entry point used by ``xla_spawn`` (TPU) launches.

    ``index`` is accepted but never read (presumably the process index
    supplied by the launcher -- confirm against the spawner); the call
    simply delegates to ``main()``.
    """
    main()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()