| """ | |
| Donut | |
| Copyright (c) 2022-present NAVER Corp. | |
| MIT License | |
| """ | |
| import math | |
| import random | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| from nltk import edit_distance | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.utils.data import DataLoader | |
| from donut import DonutConfig, DonutModel | |

class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

        if self.config.get("pretrained_model_name_or_path", False):
            self.model = DonutModel.from_pretrained(
                self.config.pretrained_model_name_or_path,
                input_size=self.config.input_size,
                max_length=self.config.max_length,
                align_long_axis=self.config.align_long_axis,
                ignore_mismatched_sizes=True,
            )
        else:
            self.model = DonutModel(
                config=DonutConfig(
                    input_size=self.config.input_size,
                    max_length=self.config.max_length,
                    align_long_axis=self.config.align_long_axis,
                    # with DonutConfig, the architecture customization is available, e.g.,
                    # encoder_layer=[2, 2, 14, 2], decoder_layer=4, ...
                )
            )
        self.pytorch_lightning_version_is_1 = int(pl.__version__[0]) < 2
        self.num_of_loaders = len(self.config.dataset_name_or_paths)

    def training_step(self, batch, batch_idx):
        # `batch` is a list with one entry per training dataloader (see
        # DonutDataPLModule.train_dataloader); concatenate them into one batch.
        image_tensors, decoder_input_ids, decoder_labels = list(), list(), list()
        for batch_data in batch:
            image_tensors.append(batch_data[0])
            decoder_input_ids.append(batch_data[1][:, :-1])
            decoder_labels.append(batch_data[2][:, 1:])
        image_tensors = torch.cat(image_tensors)
        decoder_input_ids = torch.cat(decoder_input_ids)
        decoder_labels = torch.cat(decoder_labels)
        loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
        self.log_dict({"train_loss": loss}, sync_dist=True)
        if not self.pytorch_lightning_version_is_1:
            self.log('loss', loss, prog_bar=True)
        return loss

    def on_validation_epoch_start(self) -> None:
        super().on_validation_epoch_start()
        self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]
        return

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
        decoder_prompts = pad_sequence(
            [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
            batch_first=True,
        )

        preds = self.model.inference(
            image_tensors=image_tensors,
            prompt_tensors=decoder_prompts,
            return_json=False,
            return_attentions=False,
        )["predictions"]

        scores = list()
        for pred, answer in zip(preds, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.model.decoder.tokenizer.eos_token, "")
            # normalized edit distance between predicted and ground-truth sequences
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                self.print(f"Prediction: {pred}")
                self.print(f"    Answer: {answer}")
                self.print(f" Normed ED: {scores[0]}")

        self.validation_step_outputs[dataloader_idx].append(scores)

        return scores

    def on_validation_epoch_end(self):
        assert len(self.validation_step_outputs) == self.num_of_loaders
        cnt = [0] * self.num_of_loaders
        total_metric = [0] * self.num_of_loaders
        val_metric = [0] * self.num_of_loaders
        for i, results in enumerate(self.validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # max_iter: total number of optimizer steps, either derived from
        # max_epochs (single-dataset case) or given directly as max_steps.
        max_iter = None

        if int(self.config.get("max_epochs", -1)) > 0:
            assert len(self.config.train_batch_sizes) == 1, "Set max_epochs only if the number of datasets is 1"
            max_iter = (self.config.max_epochs * self.config.num_training_samples_per_epoch) / (
                self.config.train_batch_sizes[0] * torch.cuda.device_count() * self.config.get("num_nodes", 1)
            )

        if int(self.config.get("max_steps", -1)) > 0:
            max_iter = min(self.config.max_steps, max_iter) if max_iter is not None else self.config.max_steps

        assert max_iter is not None
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr)
        scheduler = {
            "scheduler": self.cosine_scheduler(optimizer, max_iter, self.config.warmup_steps),
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    @staticmethod
    def cosine_scheduler(optimizer, training_steps, warmup_steps):
        # Linear warmup for `warmup_steps`, then cosine decay to zero over the
        # remaining steps.
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            progress = current_step - warmup_steps
            progress /= max(1, training_steps - warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda)

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        # Export the underlying model and tokenizer from rank 0 only.
        save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version
        self.model.save_pretrained(save_path)
        self.model.decoder.tokenizer.save_pretrained(save_path)
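

# Illustrative sketch (not part of the original Donut code): how the
# warmup + cosine schedule defined above behaves. It builds a throwaway
# optimizer on a single dummy parameter and prints the learning rate at a few
# steps; the step counts and base LR are arbitrary placeholder values.
# Defined as a function so nothing runs on import; safe to delete.
def _demo_cosine_schedule(training_steps: int = 100, warmup_steps: int = 10):
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([dummy_param], lr=3e-5)
    scheduler = DonutModelPLModule.cosine_scheduler(optimizer, training_steps, warmup_steps)
    for step in range(training_steps):
        optimizer.step()
        scheduler.step()
        if step in (0, warmup_steps - 1, training_steps // 2, training_steps - 1):
            # LR ramps up linearly to 3e-5 during warmup, then follows a cosine
            # curve down towards zero.
            print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.2e}")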


class DonutDataPLModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.train_batch_sizes = self.config.train_batch_sizes
        self.val_batch_sizes = self.config.val_batch_sizes
        self.train_datasets = []
        self.val_datasets = []
        self.g = torch.Generator()
        self.g.manual_seed(self.config.seed)

    def train_dataloader(self):
        loaders = list()
        for train_dataset, batch_size in zip(self.train_datasets, self.train_batch_sizes):
            loaders.append(
                DataLoader(
                    train_dataset,
                    batch_size=batch_size,
                    num_workers=self.config.num_workers,
                    pin_memory=True,
                    worker_init_fn=self.seed_worker,
                    generator=self.g,
                    shuffle=True,
                )
            )
        return loaders

    def val_dataloader(self):
        loaders = list()
        for val_dataset, batch_size in zip(self.val_datasets, self.val_batch_sizes):
            loaders.append(
                DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    pin_memory=True,
                    shuffle=False,
                )
            )
        return loaders

    @staticmethod
    def seed_worker(worker_id):
        # Re-seed NumPy and Python RNGs in each DataLoader worker so that
        # augmentations are reproducible across runs.
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
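

# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original Donut code). Training is
# normally launched through the repo's train.py with a YAML config; the shim
# below only mimics the two access patterns this module relies on (attribute
# access and .get()) so the wiring can be shown self-contained. The model id,
# dataset id, and all hyperparameter values are placeholders/assumptions.
class _SketchConfig(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


def _example_training_setup():
    config = _SketchConfig(
        pretrained_model_name_or_path="naver-clova-ix/donut-base",  # assumed HF model id
        dataset_name_or_paths=["naver-clova-ix/cord-v2"],           # assumed dataset id
        input_size=[1280, 960],
        max_length=768,
        align_long_axis=False,
        train_batch_sizes=[2],
        val_batch_sizes=[1],
        num_workers=4,
        lr=3e-5,
        warmup_steps=300,
        max_steps=3000,
        seed=2022,
        result_path="./result",
        exp_name="train_cord",
        exp_version="example",
    )
    model_module = DonutModelPLModule(config)
    data_module = DonutDataPLModule(config)
    # DonutDataset instances (built as in train.py) must be appended to
    # data_module.train_datasets / data_module.val_datasets before fitting;
    # they are omitted here to keep the sketch short.
    trainer = pl.Trainer(max_steps=config.max_steps)
    trainer.fit(model_module, data_module)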