# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to facilitate TFLite model conversion."""

import functools
from typing import Iterator, List, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks


def create_representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
  """Creates a tf.data.Dataset to load images for the representative dataset.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, the task will be built
      according to the task type in `params`.

  Returns:
    A tf.data.Dataset instance.

  Raises:
    ValueError: If the task is not supported.
  """
  if task is None:
    if isinstance(params.task,
                  configs.image_classification.ImageClassificationTask):
      task = tasks.image_classification.ImageClassificationTask(params.task)
    elif isinstance(params.task, configs.retinanet.RetinaNetTask):
      task = tasks.retinanet.RetinaNetTask(params.task)
    elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
      task = tasks.maskrcnn.MaskRCNNTask(params.task)
    elif isinstance(params.task,
                    configs.semantic_segmentation.SemanticSegmentationTask):
      task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
    else:
      raise ValueError('Task {} not supported.'.format(type(params.task)))
  # Ensure batch size is 1 for the TFLite model.
  params.task.train_data.global_batch_size = 1
  params.task.train_data.dtype = 'float32'
  logging.info('Task config: %s', params.task.as_dict())
  return task.build_inputs(params=params.task.train_data)


def representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Creates a representative dataset generator for input calibration.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, the task will be built
      according to the task type in `params`.
    calibration_steps: The number of steps to use for calibration.

  Yields:
    A list containing a single input image tensor.
  """
  dataset = create_representative_dataset(params=params, task=task)
  for image, _ in dataset.take(calibration_steps):
    # Skip images that do not have 3 channels.
    if image.shape[-1] != 3:
      continue
    yield [image]
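

# A minimal calibration sketch (illustrative only; assumes an ExperimentConfig
# `params` built elsewhere, e.g. via `official.core.exp_factory` for a
# registered vision experiment whose input pipeline yields (image, labels)
# pairs):
#
#   for tensors in representative_dataset(params, calibration_steps=10):
#     image = tensors[0]  # A single [1, height, width, 3] float32 batch.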


def convert_tflite_model(
    saved_model_dir: Optional[str] = None,
    concrete_function: Optional[tf.types.experimental.ConcreteFunction] = None,
    model: Optional[tf.Module] = None,
    quant_type: Optional[str] = None,
    params: Optional[cfg.ExperimentConfig] = None,
    task: Optional[base_task.Task] = None,
    calibration_steps: Optional[int] = 2000,
    denylisted_ops: Optional[List[str]] = None,
) -> bytes:
  """Converts and returns a TFLite model.

  Args:
    saved_model_dir: The directory containing the SavedModel.
    concrete_function: An optional concrete function to be exported.
    model: An optional tf_keras.Model instance. If neither `saved_model_dir`
      nor `concrete_function` is provided, this model is converted to TFLite.
    quant_type: The post-training quantization (PTQ) method. It can be one of
      `default` (dynamic range), `fp16` (float16), `int8` (integer with float
      fallback), `int8_full` (integer only with uint8 inputs/outputs),
      `int8_full_int8_io` (integer only with int8 inputs/outputs), `uint8`,
      `qat` (quantization-aware trained model with uint8 inputs/outputs),
      `qat_fp32_io` (quantization-aware trained model with float
      inputs/outputs), or None (no quantization).
    params: An optional ExperimentConfig used to load and preprocess input
      images for calibration during integer quantization.
    task: An optional task instance. If it is None, the task will be built
      according to the task type in `params`.
    calibration_steps: The number of steps to use for calibration.
    denylisted_ops: A list of strings containing ops that are excluded from
      integer quantization.

  Returns:
    A converted TFLite model with optional PTQ.

  Raises:
    ValueError: If none of `saved_model_dir`, `concrete_function` or `model`
      is provided, or if `quant_type` is not supported.
  """
  if saved_model_dir:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  elif concrete_function is not None:
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function]
    )
  elif model is not None:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
  else:
    raise ValueError(
        '`saved_model_dir`, `model` or `concrete_function` must be specified.'
    )

  if quant_type:
    if quant_type.startswith('int8'):
      # Integer quantization requires a representative dataset to calibrate
      # the activation ranges.
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.representative_dataset = functools.partial(
          representative_dataset,
          params=params,
          task=task,
          calibration_steps=calibration_steps)
      if quant_type.startswith('int8_full'):
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
      if quant_type == 'int8_full':
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
      if quant_type == 'int8_full_int8_io':
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

      if denylisted_ops:
        # Use the quantization debugger to selectively quantize the model,
        # keeping the denylisted ops in float.
        debug_options = tf.lite.experimental.QuantizationDebugOptions(
            denylisted_ops=denylisted_ops)
        debugger = tf.lite.experimental.QuantizationDebugger(
            converter=converter,
            debug_dataset=functools.partial(
                representative_dataset,
                params=params,
                task=task,
                calibration_steps=calibration_steps),
            debug_options=debug_options)
        debugger.run()
        return debugger.get_nondebug_quantized_model()

    elif quant_type == 'uint8':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.default_ranges_stats = (-10, 10)
      converter.inference_type = tf.uint8
      converter.quantized_input_stats = {'input_placeholder': (0., 1.)}
    elif quant_type == 'fp16':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.target_spec.supported_types = [tf.float16]
    elif quant_type in ('default', 'qat_fp32_io'):
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif quant_type == 'qat':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.inference_input_type = tf.uint8  # or tf.int8
      converter.inference_output_type = tf.uint8  # or tf.int8
    else:
      raise ValueError(f'quantization type {quant_type} is not supported.')

  return converter.convert()
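

# A minimal end-to-end sketch (illustrative only; the SavedModel path, the
# output path and `params` are placeholders, and `params` is only needed for
# the integer-quantization calibration path):
#
#   tflite_model = convert_tflite_model(
#       saved_model_dir='/tmp/exported_model/saved_model',
#       quant_type='int8',
#       params=params,
#       calibration_steps=100)
#   with tf.io.gfile.GFile('/tmp/model.tflite', 'wb') as f:
#     f.write(tflite_model)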