# Spaces: Runtime error
# File size: 13,039 Bytes
# 5672777
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification decoder and parser."""
from typing import Any, Dict, List, Optional, Tuple
# Import libraries
import tensorflow as tf, tf_keras
from official.vision.configs import common
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import augment
from official.vision.ops import preprocess_ops
# Default feature key for the image; used as the default `image_field_key`
# argument of both `Decoder` and `Parser` below.
DEFAULT_IMAGE_FIELD_KEY = 'image/encoded'
# Default feature key for the class label; used as the default
# `label_field_key` argument of both `Decoder` and `Parser` below.
DEFAULT_LABEL_FIELD_KEY = 'image/class/label'
class Decoder(decoder.Decoder):
  """Decodes serialized tf.Example records for the classification task."""

  def __init__(self,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               is_multilabel: bool = False,
               keys_to_features: Optional[Dict[str, Any]] = None):
    """Builds the feature spec used by `decode`.

    Args:
      image_field_key: `str`, feature key holding the image.
      label_field_key: `str`, feature key holding the label.
      is_multilabel: `bool`, if True the label is declared as a variable-length
        int64 feature; otherwise as a scalar int64 feature defaulting to -1.
      keys_to_features: optional feature spec. When given and non-empty it is
        used as-is and the three arguments above are ignored.
    """
    # A caller-supplied, non-empty spec wins outright.
    if keys_to_features:
      self._keys_to_features = keys_to_features
      return

    if is_multilabel:
      label_feature = tf.io.VarLenFeature(dtype=tf.int64)
    else:
      label_feature = tf.io.FixedLenFeature((), tf.int64, default_value=-1)

    self._keys_to_features = {
        image_field_key:
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        label_field_key: label_feature,
    }

  def decode(self, serialized_example):
    """Parses one serialized example with the configured feature spec."""
    feature_spec = self._keys_to_features
    return tf.io.parse_single_example(serialized_example, feature_spec)
class Parser(parser.Parser):
  """Parses decoded classification examples into `(image, label)` tensors."""

  def __init__(self,
               output_size: List[int],
               num_classes: float,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               decode_jpeg_only: bool = True,
               aug_rand_hflip: bool = True,
               aug_crop: Optional[bool] = True,
               aug_type: Optional[common.Augmentation] = None,
               color_jitter: float = 0.,
               random_erasing: Optional[common.RandomErasing] = None,
               is_multilabel: bool = False,
               dtype: str = 'float32',
               crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0),
               center_crop_fraction: Optional[
                   float] = preprocess_ops.CENTER_CROP_FRACTION,
               tf_resize_method: str = 'bilinear',
               three_augment: bool = False):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image. Used
        as the resize target for both training and evaluation images.
      num_classes: number of classes. Only read for multilabel targets, where
        it is passed as the `depth` of `tf.one_hot`.
      image_field_key: `str`, the key name to encoded image or decoded image
        matrix in tf.Example.
      label_field_key: `str`, the key name to label in tf.Example.
      decode_jpeg_only: `bool`, if True (and `aug_crop` is set and the image
        still needs decoding), the JPEG-specific crop ops are used on the
        encoded bytes. Default is True.
      aug_rand_hflip: `bool`, if True, augment training with random horizontal
        flip.
      aug_crop: `bool`, if True, perform random cropping during training and
        center crop during validation.
      aug_type: An optional Augmentation object to choose from AutoAugment and
        RandAugment (`type` must be 'autoaug' or 'randaug').
      color_jitter: Magnitude of color jitter. If > 0, the value is passed to
        `preprocess_ops.color_jitter` as the brightness, contrast and
        saturation arguments.
      random_erasing: if not None, augment input image by random erasing. See
        `augment.RandomErasing` for more details.
      is_multilabel: A `bool`, whether or not each example has multiple labels.
      dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
        or 'bfloat16'.
      crop_area_range: An optional `tuple` of (min_area, max_area) for image
        random crop function to constrain crop operation. The cropped areas
        of the image must contain a fraction of the input image within this
        range. The default area range is (0.08, 1.0).
        https://arxiv.org/abs/2204.07118.
      center_crop_fraction: fraction forwarded to the center-crop ops when
        parsing evaluation images.
      tf_resize_method: A `str`, interpolation method for resizing image.
      three_augment: A bool, whether to apply the 'deit3_three_augment'
        AutoAugment policy during training.

    Raises:
      ValueError: if `dtype` is not one of the supported names, or if
        `aug_type.type` is neither 'autoaug' nor 'randaug'.
    """
    self._output_size = output_size
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_crop = aug_crop
    self._num_classes = num_classes
    self._image_field_key = image_field_key

    if dtype == 'float32':
      self._dtype = tf.float32
    elif dtype == 'float16':
      self._dtype = tf.float16
    elif dtype == 'bfloat16':
      self._dtype = tf.bfloat16
    else:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))

    if aug_type:
      if aug_type.type == 'autoaug':
        self._augmenter = augment.AutoAugment(
            augmentation_name=aug_type.autoaug.augmentation_name,
            cutout_const=aug_type.autoaug.cutout_const,
            translate_const=aug_type.autoaug.translate_const)
      elif aug_type.type == 'randaug':
        self._augmenter = augment.RandAugment(
            num_layers=aug_type.randaug.num_layers,
            magnitude=aug_type.randaug.magnitude,
            cutout_const=aug_type.randaug.cutout_const,
            translate_const=aug_type.randaug.translate_const,
            prob_to_apply=aug_type.randaug.prob_to_apply,
            exclude_ops=aug_type.randaug.exclude_ops)
      else:
        raise ValueError('Augmentation policy {} not supported.'.format(
            aug_type.type))
    else:
      self._augmenter = None

    self._label_field_key = label_field_key
    self._color_jitter = color_jitter

    if random_erasing:
      self._random_erasing = augment.RandomErasing(
          probability=random_erasing.probability,
          min_area=random_erasing.min_area,
          max_area=random_erasing.max_area,
          min_aspect=random_erasing.min_aspect,
          max_aspect=random_erasing.max_aspect,
          min_count=random_erasing.min_count,
          max_count=random_erasing.max_count,
          trials=random_erasing.trials)
    else:
      self._random_erasing = None

    self._is_multilabel = is_multilabel
    self._decode_jpeg_only = decode_jpeg_only
    self._crop_area_range = crop_area_range
    self._center_crop_fraction = center_crop_fraction
    self._tf_resize_method = tf_resize_method
    self._three_augment = three_augment

  def _parse_label(self, decoded_tensors):
    """Extracts the label; shared by the training and evaluation paths.

    The label is cast to int32. For multilabel parsers a sparse label is
    densified, one-hot encoded with depth `num_classes`, and summed over
    axis 0.
    """
    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
    if self._is_multilabel:
      if isinstance(label, tf.sparse.SparseTensor):
        label = tf.sparse.to_dense(label)
      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
    return label

  def _maybe_decode_image(self, image_bytes, require_decoding):
    """Decodes `image_bytes` to 3 channels, or returns it as-is if decoded."""
    if not require_decoding:
      # Already decoded image matrix.
      return image_bytes
    image = tf.io.decode_image(image_bytes, channels=3)
    image.set_shape([None, None, 3])
    return image

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training; returns an `(image, label)` tuple."""
    image = self._parse_train_image(decoded_tensors)
    label = self._parse_label(decoded_tensors)
    return image, label

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation; returns an `(image, label)` tuple."""
    image = self._parse_eval_image(decoded_tensors)
    label = self._parse_label(decoded_tensors)
    return image, label

  def _parse_train_image(self, decoded_tensors):
    """Parses image data for training.

    Steps, in order: crop (optional), random horizontal flip (optional), color
    jitter (optional), resize to `output_size`, AutoAugment/RandAugment
    (optional), three-augment (optional), mean/std normalization, random
    erasing (optional), and conversion to the configured dtype.
    """
    image_bytes = decoded_tensors[self._image_field_key]
    # Non-tensor inputs and string tensors are treated as encoded images.
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )

    if require_decoding and self._decode_jpeg_only and self._aug_crop:
      # JPEG path: the crop ops take the encoded bytes plus the shape read
      # from the JPEG header.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)
      cropped_image = preprocess_ops.random_crop_image_v2(
          image_bytes, image_shape, area_range=self._crop_area_range)
      # If the random crop has the same shape as the full image, use a center
      # crop instead (presumably the random crop did not find a valid box).
      image = tf.cond(
          tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
          lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
          lambda: cropped_image)
    else:
      decoded_image = self._maybe_decode_image(image_bytes, require_decoding)
      image = decoded_image
      if self._aug_crop:
        cropped_image = preprocess_ops.random_crop_image(
            decoded_image, area_range=self._crop_area_range)
        # Same fallback as above, on the decoded image. The branches close
        # over `decoded_image` so they never see the reassigned `image`.
        image = tf.cond(
            tf.reduce_all(
                tf.equal(tf.shape(cropped_image), tf.shape(decoded_image))),
            lambda: preprocess_ops.center_crop_image(decoded_image),
            lambda: cropped_image)

    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    # Color jitter: the same magnitude is passed for all three arguments.
    if self._color_jitter > 0:
      image = preprocess_ops.color_jitter(image, self._color_jitter,
                                          self._color_jitter,
                                          self._color_jitter)

    # Resizes image.
    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])

    # Apply autoaug or randaug.
    if self._augmenter is not None:
      image = self._augmenter.distort(image)

    # Three augmentation.
    if self._three_augment:
      image = augment.AutoAugment(
          augmentation_name='deit3_three_augment',
          translate_const=20,
      ).distort(image)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Random erasing after the image has been normalized.
    if self._random_erasing is not None:
      image = self._random_erasing.distort(image)

    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image

  def _parse_eval_image(self, decoded_tensors):
    """Parses image data for evaluation.

    Steps, in order: center crop with `center_crop_fraction` (only when
    `aug_crop` is set), resize to `output_size`, mean/std normalization, and
    conversion to the configured dtype.
    """
    image_bytes = decoded_tensors[self._image_field_key]
    # Non-tensor inputs and string tensors are treated as encoded images.
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )

    if require_decoding and self._decode_jpeg_only and self._aug_crop:
      # JPEG path: center crop from the encoded bytes and header shape.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)
      image = preprocess_ops.center_crop_image_v2(
          image_bytes, image_shape, self._center_crop_fraction)
    else:
      image = self._maybe_decode_image(image_bytes, require_decoding)
      # Center crops.
      if self._aug_crop:
        image = preprocess_ops.center_crop_image(
            image, self._center_crop_fraction)

    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image

  def parse_train_image(self, decoded_tensors: Dict[str,
                                                    tf.Tensor]) -> tf.Tensor:
    """Public interface for parsing image data for training."""
    return self._parse_train_image(decoded_tensors)

  @classmethod
  def inference_fn(cls,
                   image: tf.Tensor,
                   input_image_size: List[int],
                   num_channels: int = 3) -> tf.Tensor:
    """Builds image model inputs for serving.

    The image is cast to float32, center-cropped with
    `preprocess_ops.center_crop_image`, resized bilinearly to
    `input_image_size`, and normalized with the mean/std pixel values.

    Args:
      image: the decoded input image tensor.
      input_image_size: [height, width] to resize to. Any sequence is accepted
        (list or tuple).
      num_channels: number of channels recorded in the static output shape.

    Returns:
      The preprocessed image with static shape
      `input_image_size + [num_channels]`.
    """
    image = tf.cast(image, dtype=tf.float32)
    image = preprocess_ops.center_crop_image(image)
    image = tf.image.resize(
        image, input_image_size, method=tf.image.ResizeMethod.BILINEAR)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # `list(...)` so a tuple-valued size does not raise on `tuple + list`.
    image.set_shape(list(input_image_size) + [num_channels])
    return image
|