# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "vitstr_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (32, 128, 3),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0",
    },
    "vitstr_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (32, 128, 3),
        "vocab": VOCABS["french"],
        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0",
    },
}
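
# Note (illustrative sketch, not library code): these entries are consumed by the
# `_vitstr` builder below; `vocab` and `input_shape` can be overridden through the
# keyword arguments of the public constructors, e.g.:
# >>> from doctr.datasets import VOCABS
# >>> model = vitstr_small(pretrained=False, vocab=VOCABS["english"])
# >>> model.cfg["vocab"] == VOCABS["english"]  # the patched config is kept on the model
# True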

class ViTSTR(_ViTSTR, Model):
    """Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
    Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
        max_length: maximum word length handled by the model
        dropout_prob: dropout probability for the encoder and decoder
        input_shape: input shape of the image
        exportable: if True, the model returns only the logits, for ONNX export
        cfg: dictionary containing information about the model
    """

    _children_names: List[str] = ["feat_extractor", "postprocessor"]

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        embedding_units: int,
        max_length: int = 32,
        dropout_prob: float = 0.0,
        input_shape: Tuple[int, int, int] = (32, 128, 3),  # different from paper
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length + 2  # +2 for SOS and EOS

        self.feat_extractor = feature_extractor
        self.head = layers.Dense(len(self.vocab) + 1, name="head")  # +1 for EOS

        self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

    @staticmethod
    def compute_loss(
        model_output: tf.Tensor,
        gt: tf.Tensor,
        seq_len: List[int],
    ) -> tf.Tensor:
        """Compute categorical cross-entropy loss for the model.

        Sequences are masked after the EOS character.

        Args:
        ----
            model_output: predicted logits of the model
            gt: the encoded tensor with gt labels
            seq_len: lengths of each gt word inside the batch

        Returns:
        -------
            The loss of the model on the batch
        """
        # Input length: number of steps
        input_len = tf.shape(model_output)[1]
        # Add one for the additional <eos> token (the <sos> token disappears in the shift!)
        seq_len = tf.cast(seq_len, tf.int32) + 1
        # One-hot gt labels
        oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
        # Compute loss: don't forget to shift gt! Otherwise the model learns to output gt[t-1]!
        # The "masked" first gt char is <sos>.
        cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output)
        # Compute mask to zero out the loss after the <eos> position
        mask_values = tf.zeros_like(cce)
        mask_2d = tf.sequence_mask(seq_len, input_len)
        masked_loss = tf.where(mask_2d, cce, mask_values)
        ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))

        return tf.expand_dims(ce_loss, axis=1)
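
    # Illustrative sketch (an assumption for exposition, not library code): with a
    # target "cat" encoded as [<sos>, c, a, t, <eos>, <pad>, ...] and seq_len = [3],
    # the shift drops <sos> and seq_len becomes 4 (word + <eos>), so the mask keeps
    # exactly the supervised steps and zeroes the loss on padding:
    # >>> import tensorflow as tf
    # >>> tf.sequence_mask([4], 6).numpy()
    # array([[ True,  True,  True,  True, False, False]])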

    def call(
        self,
        x: tf.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)

        if target is not None:
            gt, seq_len = self.build_target(target)
            seq_len = tf.cast(seq_len, tf.int32)

        if kwargs.get("training", False) and target is None:
            raise ValueError("Need to provide labels during training")

        features = features[:, : self.max_length]  # (batch_size, max_length, d_model)
        B, N, E = features.shape
        features = tf.reshape(features, (B * N, E))
        logits = tf.reshape(
            self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
        )  # (batch_size, max_length, vocab + 1)
        decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

        out: Dict[str, tf.Tensor] = {}

        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Post-process the logits into (word, confidence) predictions
            out["preds"] = self.postprocessor(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len)

        return out
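
# Usage sketch (an assumption for exposition, not library code): without a target the
# model returns decoded predictions; with a target it also returns the loss.
# >>> import tensorflow as tf
# >>> model = vitstr_small(pretrained=False)
# >>> x = tf.random.uniform(shape=[2, 32, 128, 3], maxval=1, dtype=tf.float32)
# >>> model(x)["preds"]  # list of (word, confidence) pairs
# >>> model(x, target=["hello", "world"])["loss"]  # shape (2, 1)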


class ViTSTRPostProcessor(_ViTSTRPostProcessor):
    """Post processor for ViTSTR architecture

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: tf.Tensor,
    ) -> List[Tuple[str, float]]:
        # compute pred with argmax for attention models
        out_idxs = tf.math.argmax(logits, axis=2)
        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

        # decode raw output of the model with tf_label_to_idx
        out_idxs = tf.cast(out_idxs, dtype="int32")
        embedding = tf.constant(self._embedding, dtype=tf.string)
        decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
        decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
        word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

        # compute probabilities for each word up to the EOS token
        probs = [
            preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
            for i, word in enumerate(word_values)
        ]

        return list(zip(word_values, probs))
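
# Decoding sketch (an assumption for exposition, not library code): `self._embedding`
# maps class indices back to characters, with the extra last class standing for
# "<eos>", so decoding is argmax -> character lookup -> join -> cut at the first
# "<eos>". For a toy vocab "ab":
# >>> emb = tf.constant(["a", "b", "<eos>"], dtype=tf.string)
# >>> joined = tf.strings.reduce_join(tf.nn.embedding_lookup(emb, [[0, 1, 2, 0]]), axis=-1)
# >>> tf.strings.split(joined, "<eos>")[0][0]  # everything after <eos> is discarded
# <tf.Tensor: shape=(), dtype=string, numpy=b'ab'>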


def _vitstr(
    arch: str,
    pretrained: bool,
    backbone_fn,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> ViTSTR:
    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    patch_size = kwargs.get("patch_size", (4, 8))

    kwargs["vocab"] = _cfg["vocab"]

    # Feature extractor
    feat_extractor = backbone_fn(
        # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
        pretrained=False,
        input_shape=_cfg["input_shape"],
        patch_size=patch_size,
        include_top=False,
    )

    kwargs.pop("patch_size", None)
    kwargs.pop("pretrained_backbone", None)

    # Build the model
    model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, default_cfgs[arch]["url"])

    return model
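
# Builder sketch (an assumption for exposition, not library code): `_vitstr` is the
# shared entry point behind both public constructors; the call below mirrors what
# `vitstr_small` does internally.
# >>> model = _vitstr("vitstr_small", pretrained=False, backbone_fn=vit_s,
# ...                 embedding_units=384, patch_size=(4, 8))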


def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
    """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import vitstr_small
    >>> model = vitstr_small(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _vitstr(
        "vitstr_small",
        pretrained,
        vit_s,
        embedding_units=384,
        patch_size=(4, 8),
        **kwargs,
    )


def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
    """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
    <https://arxiv.org/pdf/2105.08582.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import vitstr_base
    >>> model = vitstr_base(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the ViTSTR architecture

    Returns:
    -------
        text recognition architecture
    """
    return _vitstr(
        "vitstr_base",
        pretrained,
        vit_b,
        embedding_units=768,
        patch_size=(4, 8),
        **kwargs,
    )