| import numpy as np | |
| import datetime | |
def print_batch(tokenizer, batch, n, header=None):
    """
    Decode and print the first ``n`` sequences of a batch (debugging aid).

    Parameters
    ------------
    tokenizer : object
        Any tokenizer exposing ``batch_decode(sequences, skip_special_tokens=...)``;
        it is invoked with ``skip_special_tokens=True``.
    batch : List of List[int]
        token id sequences to decode
    n : int
        number of sentences to print from the batch
    header : str, optional
        title shown before the sentences; "Batch" is used when it is
        omitted or empty
    """
    title = header if header else "Batch"
    print(f'=== {title} ===')

    decoded = tokenizer.batch_decode(batch[:n], skip_special_tokens=True)
    print(decoded)

    # An ellipsis line signals that part of the batch was not shown.
    if n < len(batch):
        trailer = '...\n'
    else:
        trailer = ''
    print(trailer)
def flat_accuracy(preds, labels):
    """
    Compute the fraction of predictions that match the labels.

    Parameters
    ------------
    preds : array-like
        per-class scores; the predicted class is the argmax along axis 1
    labels : array-like
        ground-truth class indices; flattened before the comparison

    Returns
    ------------
    numpy.float64
        number of matching entries divided by the number of labels
        (NaN, with a NumPy RuntimeWarning, when there are no labels)

    Raises
    ------------
    ValueError
        if the flattened predictions and labels differ in length
    """
    # np.asarray lets plain Python lists through as well as ndarrays.
    pred_flat = np.argmax(np.asarray(preds), axis=1).ravel()
    labels_flat = np.asarray(labels).ravel()

    # Without this guard a length-1 label array would broadcast against all
    # predictions and could produce an "accuracy" greater than 1.
    if pred_flat.shape != labels_flat.shape:
        raise ValueError(
            f'preds yield {pred_flat.size} predictions but labels has '
            f'{labels_flat.size} entries'
        )

    return np.sum(pred_flat == labels_flat) / len(labels_flat)
def format_time(elapsed):
    """
    Render a duration given in seconds as a human-readable string.

    The value is rounded to whole seconds and formatted the way
    ``datetime.timedelta`` prints itself, e.g. ``0:01:05`` (a day count
    is prefixed for durations of a day or more).
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)