import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import wandb
import datasets
import evaluate
from datasets import load_dataset
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions

import transformers
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler


@dataclass
class ModelArguments:
    """Command-line arguments naming the pretrained model/tokenizer to load and where to cache it."""

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
    )


@dataclass
class DataTrainingArguments:
    """Command-line arguments describing the dataset and how it is tokenized and post-processed.

    Note that ``main`` only loads data through ``dataset_name``; ``train_file``, ``validation_file`` and
    ``test_file`` are accepted by the parser but are not read anywhere in this script.
    """

    dataset_name: Optional[str] = field(
        default="squad", metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=10,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
                " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
            )
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={
            "help": (
                "The threshold used to select the null answer: if the best answer has a score that is less than "
                "the score of the null answer minus this threshold, the null answer is selected for this example. "
                "Only useful when `version_2_with_negative=True`."
            )
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    n_best_size: int = field(
        default=20,
        metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": (
                "The maximum length of an answer that can be generated. This is needed because the start "
                "and end predictions are not conditioned on one another."
            )
        },
    )


def main():
    """Run a Ray Tune hyperparameter search for an extractive question-answering model.

    Steps, in order: start a wandb run, parse the command line (or a single ``.json`` file) into
    ``ModelArguments`` / ``DataTrainingArguments`` / ``TrainingArguments``, load and tokenize the dataset,
    rebuild ``TrainingArguments`` with the hard-coded search configuration, run
    ``trainer.hyperparameter_search`` with the ``ray`` backend, load the best checkpoint, and push the
    result to the hub.

    Raises:
        ValueError: if no dataset name is given, if a requested split is missing, or if the train and
            validation sets were not preprocessed (``--do_train`` / ``--do_eval`` not passed) even
            though the search always trains and evaluates.
    """
    wandb.init(
        project="QA_test",
    )

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Only hub datasets are supported: the *_file arguments are never read, so fail with a clear message
    # instead of hitting an unbound `raw_datasets` further down.
    if data_args.dataset_name is None:
        raise ValueError(
            "A dataset name is required: loading from --train_file/--validation_file/--test_file "
            "is not implemented in this script."
        )

    # Downloading and loading a dataset from the hub. Only the `train[:20]` slice is requested, which is
    # then divided by `train_test_split`; the "train" and "test" entries used below come from that split.
    raw_datasets = load_dataset(
        data_args.dataset_name,
        cache_dir=model_args.cache_dir,
        split="train[:20]",
    )
    raw_datasets = raw_datasets.train_test_split(test_size=0.2)
    raw_datasets["validation"] = load_dataset(
        data_args.dataset_name,
        cache_dir=model_args.cache_dir,
        split="validation",
    )
    print(raw_datasets)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
    )

    def get_model():
        """Load the pretrained question-answering model; passed to the trainer as `model_init`."""
        return AutoModelForQuestionAnswering.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        """Tokenize a batch of examples and label each resulting feature with answer start/end token indices.

        Features whose span does not contain the answer (or whose example has no answer) are labeled
        with the index of the CLS token for both positions.
        """
        # Strip leading whitespace from the questions before tokenizing.
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        # Index of the sequence (question or context) that holds the context tokens.
        context_sequence_id = 1 if pad_on_right else 0

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != context_sequence_id:
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != context_sequence_id:
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    # Bound up front so that a split that was not requested on the command line is `None` rather than an
    # unbound name when the trainer is built.
    train_dataset = None
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                prepare_train_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )

    # Validation preprocessing
    def prepare_validation_features(examples):
        """Tokenize a batch of examples for evaluation/prediction, keeping what post-processing needs.

        Each feature gets the `example_id` of the example it came from, and its `offset_mapping` is kept
        with the entries of non-context tokens replaced by None.
        """
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
        # left whitespace
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        context_index = 1 if pad_on_right else 0

        for i in range(len(tokenized_examples["input_ids"])):
            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    eval_examples = None
    eval_dataset = None
    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_examples.map(
                prepare_validation_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )

    predict_examples = None
    predict_dataset = None
    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        # Predict Feature Creation
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_examples.map(
                prepare_validation_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )

    # When everything is already padded to `max_seq_length` the default collator is enough; otherwise pad
    # per batch. This reads `fp16` from the parsed arguments, before they are rebuilt below.
    data_collator = (
        default_data_collator
        if data_args.pad_to_max_length
        else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
    )

    # Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        """Turn raw model predictions into answer texts plus references, wrapped in an `EvalPrediction`.

        NOTE: `training_args` is looked up when this function is called, so `output_dir` comes from the
        `TrainingArguments` rebuilt further down in `main`, not from the parsed command line.
        """
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            null_score_diff_threshold=data_args.null_score_diff_threshold,
            output_dir=training_args.output_dir,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [
                {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]

        references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)

    metric = evaluate.load(
        "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
    )

    def compute_metrics(p):
        """Score the formatted predictions in `p` against its references with the loaded metric."""
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    # The configuration below always trains and evaluates, but the datasets above were only prepared for
    # the splits requested on the command line. Fail early with an actionable message.
    if train_dataset is None or eval_dataset is None:
        raise ValueError(
            "The hyperparameter search trains and evaluates every trial: pass both --do_train and --do_eval "
            "so that the train and validation datasets are preprocessed."
        )

    # From here on the parsed training arguments are replaced by this hard-coded configuration.
    training_args = TrainingArguments(
        output_dir=".",
        learning_rate=1e-5,  # config
        do_train=True,
        do_eval=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        num_train_epochs=2,  # config
        max_steps=-1,
        per_device_train_batch_size=16,  # config
        per_device_eval_batch_size=16,  # config
        warmup_steps=0,
        weight_decay=0.1,  # config
        logging_dir="./logs",
        skip_memory_metrics=True,
        report_to="wandb",
        disable_tqdm=True,
        metric_for_best_model="f1",
    )

    trainer = QuestionAnsweringTrainer(
        model_init=get_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        eval_examples=eval_examples,
        tokenizer=tokenizer,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        compute_metrics=compute_metrics,
    )

    # Search space handed to Ray Tune; only the learning rate is declared as a (single-value) grid.
    tune_config = {
        "per_device_train_batch_size": 32,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": 1,
        "learning_rate": tune.grid_search([2e-5]),
    }

    scheduler = ASHAScheduler(
        metric="eval_f1",
        mode="max",
        time_attr="training_iteration",
        max_t=50,
        grace_period=10,
        reduction_factor=3,
        brackets=1,
    )

    reporter = CLIReporter(
        parameter_columns={
            "weight_decay": "w_decay",
            "learning_rate": "lr",
            "per_device_train_batch_size": "train_bs/gpu",
            "num_train_epochs": "num_epochs",
        },
        metric_columns=["eval_exact", "eval_f1"],
    )

    def compute_objective(metrics):
        """Return the F1 score from an evaluation's metrics as the objective of the search."""
        return metrics["eval_f1"]

    results = trainer.hyperparameter_search(
        hp_space=lambda _: tune_config,
        backend="ray",
        n_trials=1,
        scheduler=scheduler,
        keep_checkpoints_num=1,
        progress_reporter=reporter,
        local_dir="./runs",
        log_to_file=True,
        direction="maximize",
        checkpoint_score_attr="training_iteration",
        compute_objective=compute_objective,
    )

    # Locate the best trial's checkpoint directory.
    # NOTE(review): the "checkpoint-1" sub-directory name is hard-coded — confirm it matches the step at
    # which the trainer actually saves for the configured dataset size and batch size.
    run_summary = results.run_summary
    best_trial = run_summary.get_best_trial(metric="eval_f1", mode="max")
    best_trial_checkpoint = run_summary.get_best_checkpoint(best_trial, metric="eval_f1", mode="max")
    best_checkpoint = os.path.join(best_trial_checkpoint.path, "checkpoint-1")

    # Loaded but not used afterwards in this script; loading still verifies the checkpoint is readable.
    model_retrain = AutoModelForQuestionAnswering.from_pretrained(best_checkpoint)

    # Prediction
    # NOTE(review): the TrainingArguments rebuilt above does not set `do_predict`, so this branch only
    # runs if that argument defaults to a truthy value — confirm whether prediction is meant to be reachable.
    if training_args.do_predict:
        results = trainer.predict(predict_dataset, predict_examples)
        metrics = results.metrics

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        kwargs["dataset"] = data_args.dataset_name

    # Called unconditionally, with the model-card metadata assembled above.
    trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    main()