author      2018-05-31 18:29:32 -0700
committer   2018-05-31 18:29:32 -0700
commit      7a9f1e6ec8ae53cd4321e819bd2343429a5ea9eb (patch)
tree        430e7ffc7d8157f2e0b4bb6d814e2eb0d375d31b /tensorflow/contrib/lite/python
parent      e365deab1333005c8aa186632f160c1bfd4485f8 (diff)
parent      2e272dbca6600991599e55a7ff7cfa668b8403aa (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--  tensorflow/contrib/lite/python/BUILD                        |  19
-rw-r--r--  tensorflow/contrib/lite/python/convert.py                   |  63
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model.py       | 118
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model_test.py  |  55
-rw-r--r--  tensorflow/contrib/lite/python/lite.py                      | 218
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py                 | 241
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py            | 324
7 files changed, 909 insertions, 129 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index a40e512045..7e6ff6c0a8 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -36,6 +36,16 @@ py_test( ], ) +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + py_library( name = "lite", srcs = ["lite.py"],
@@ -125,6 +135,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":convert", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", "//tensorflow/python:platform",
@@ -164,11 +175,3 @@ py_test( "//tensorflow/python/saved_model", ], ) - -# Transitive dependencies of this target will be included in the pip package. -py_library( - name = "tf_lite_py_pip", - deps = [ - ":convert_saved_model", - ], -)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index c0926d2f33..0819475240 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -115,11 +115,15 @@ def toco_convert(input_data, input_tensors, output_tensors, inference_type=lite_constants.FLOAT, + inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, output_format=lite_constants.TFLITE, quantized_input_stats=None, + default_ranges_stats=None, drop_control_dependency=True, - allow_custom_ops=False): + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False): """Convert a model using TOCO from `input_format` to `output_format`. Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -130,18 +134,41 @@ def toco_convert(input_data, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + input_format: Type of data to read. Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only needed if + `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) Returns: - The converted data. For example if tflite was the destination, then + The converted data. For example if TFLite was the destination, then this will be a tflite flatbuffer in a bytes array. Raises:
@@ -152,10 +179,18 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format - toco.drop_control_dependency = drop_control_dependency - model = _model_flags_pb2.ModelFlags() toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type + toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + + model = _model_flags_pb2.ModelFlags() + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): if input_tensor.dtype == _dtypes.float32: tflite_input_type = lite_constants.FLOAT
@@ -163,6 +198,8 @@ def toco_convert(input_data, tflite_input_type = lite_constants.INT32 elif input_tensor.dtype == _dtypes.int64: tflite_input_type = lite_constants.INT64 + elif input_tensor.dtype == _dtypes.uint8: + tflite_input_type = lite_constants.QUANTIZED_UINT8 # TODO(aselle): Insert strings when they are available else: raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
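As a usage sketch (not part of the patch): with every new argument left at its default, `toco_convert` still performs a plain float conversion. The tensor name and shape below are invented for illustration.

    import tensorflow as tf
    from tensorflow.contrib.lite.python.convert import toco_convert

    img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 4))
    out = img + img
    sess = tf.Session()

    # Returns the TFLite flatbuffer as a bytes object.
    tflite_model = toco_convert(sess.graph_def, [img], [out])
    open("/tmp/model.tflite", "wb").write(tflite_model)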
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 54fec9d61f..b952a72aab 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -18,31 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -def _write_and_flush_file(file_path, data_str): - """Writes data to file path. - - Args: - file_path: Full path of the file to store data in. - data_str: Data represented as a string. - - Returns: None. - """ - with gfile.Open(file_path, "wb") as data_file: - data_file.write(data_str) - data_file.flush() def _log_tensor_details(tensor_info):
@@ -167,29 +151,10 @@ def _get_tensors(graph, signature_def_tensor_names=None, """ tensors = [] if user_tensor_names: - # Get the list of all of the tensors with and without the tensor index. - all_tensor_names = [ - tensor.name for op in graph.get_operations() for tensor in op.outputs - ] - all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] - # Sort the tensor names. user_tensor_names = sorted(user_tensor_names) - # Get the tensors associated with the tensor names. - tensors = [] - invalid_tensors = [] - for name in user_tensor_names: - if name not in all_tensor_names_only: - invalid_tensors.append(name) - else: - idx = all_tensor_names_only.index(name) - tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) - - # Throw ValueError if any user input names are not valid tensors. - if invalid_tensors: - raise ValueError("Invalid tensors '{}' were found.".format( - ",".join(invalid_tensors))) + tensors = get_tensors_from_tensor_names(graph, user_tensor_names) elif signature_def_tensor_names: tensors = [ graph.get_tensor_by_name(name)
@@ -204,6 +169,58 @@ return tensors +def get_tensors_from_tensor_names(graph, tensor_names): + """Gets the Tensors associated with the `tensor_names` in the provided graph. + + Args: + graph: TensorFlow Graph. + tensor_names: List of strings that represent names of tensors in the graph. + + Returns: + A list of Tensor objects in the same order the names are provided. + + Raises: + ValueError: + tensor_names contains an invalid tensor name. + """ + # Get the list of all of the tensors. + tensor_name_to_tensor = { + tensor_name(tensor): tensor for op in graph.get_operations() + for tensor in op.values() + } + + # Get the tensors associated with tensor_names. + tensors = [] + invalid_tensors = [] + for name in tensor_names: + tensor = tensor_name_to_tensor.get(name) + if tensor is None: + invalid_tensors.append(name) + else: + tensors.append(tensor) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + return tensors + + +def set_tensor_shapes(tensors, shapes): + """Sets Tensor shape for each tensor if the shape is defined. + + Args: + tensors: TensorFlow ops.Tensor. + shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + """ + if shapes: + for tensor in tensors: + shape = shapes.get(tensor.name) + if shape is not None: + tensor.set_shape(shapes[tensor.name]) + + def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph.
@@ -211,15 +228,14 @@ Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of + from SignatureDef when none are provided. + input_shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) + arrays from SignatureDef when none are provided. tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. signature_key: Key identifying SignatureDef containing inputs and outputs. Returns:
@@ -233,14 +249,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_key is not in the MetaGraphDef. input_shapes does not match the length of input_arrays. input_arrays or output_arrays are not valid. - Unable to load Session. """ - # Set default values for inputs if they are set to None. - if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if tag_set is None: - tag_set = set([tag_constants.SERVING]) - # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) signature_def = _get_signature_def(meta_graph, signature_key)
@@ -255,19 +264,10 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - - # Gets fully defined tensor shape. - for tensor in in_tensors: - if (input_shapes and tensor.name in input_shapes and - input_shapes[tensor.name] is not None): - shape = input_shapes[tensor.name] - else: - shape = tensor.get_shape().as_list() - tensor.set_shape(shape) + set_tensor_shapes(in_tensors, input_shapes) output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) return frozen_graph_def, in_tensors, out_tensors - raise ValueError("Unable to load Session.")
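The two helpers factored out above can also be called directly. A small sketch with a hypothetical tensor name: names passed to get_tensors_from_tensor_names omit the output index, while set_tensor_shapes keys on the full tensor name, mirroring the tests below.

    import tensorflow as tf
    from tensorflow.contrib.lite.python import convert_saved_model

    x = tf.placeholder(name="x", dtype=tf.float32, shape=(None, 3))
    _ = x + x
    sess = tf.Session()

    # Look the tensor up by name, then pin the undefined batch dimension.
    tensors = convert_saved_model.get_tensors_from_tensor_names(sess.graph, ["x"])
    convert_saved_model.set_tensor_shapes(tensors, {"x:0": [1, 3]})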
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index f69381d0e6..80e5dc6e46 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -41,9 +41,58 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import training as train +class TensorFunctionsTest(test_util.TensorFlowTestCase): + + def testGetTensorsValid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + tensors = convert_saved_model.get_tensors_from_tensor_names( + sess.graph, ["Placeholder"]) + self.assertEqual("Placeholder:0", tensors[0].name) + + def testGetTensorsInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + with self.assertRaises(ValueError) as error: + convert_saved_model.get_tensors_from_tensor_names(sess.graph, + ["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + def testSetTensorShapeValid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"Placeholder:0": [5, 3, 5]}) + self.assertEqual([5, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeInvalid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"invalid-input": [5, 3, 5]}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeEmpty(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + class FreezeSavedModelTest(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape):
@@ -93,6 +142,10 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase): output_arrays=None, tag_set=None, signature_key=None): + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( saved_model_dir=saved_model_dir, input_arrays=input_arrays,
@@ -390,7 +443,7 @@ class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): input_arrays=None, input_shapes=None, output_arrays=["Softmax"], - tag_set=None, + tag_set=set([tag_constants.SERVING]), signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index f7f2d40a02..253b5eadf3 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -33,15 +33,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.ops.variables import global_variables_initializer from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants
@@ -55,26 +62,50 @@ class TocoConverter(object): Attributes: - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - (default FLOAT) - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT). (default TFLITE) - quantized_input_stats: The mean and std deviation of training data for each - input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`. - (default None) + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only needed if + `inference_type` is `QUANTIZED_UINT8`. (default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) drop_control_dependency: Boolean indicating whether to drop control dependencies silently. This is due to TFLite not supporting control dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. (default False) Example usage: - # Converting a frozen graph. + # Converting a GraphDef from session. converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) + # Converting a GraphDef from file. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + # Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert()
@@ -94,17 +125,17 @@ class TocoConverter(object): self._input_tensors = input_tensors self._output_tensors = output_tensors self.inference_type = constants.FLOAT + self.inference_input_type = None self.output_format = constants.TFLITE - self.quantized_input_stats = None + self.quantized_input_stats = {} + self.default_ranges_stats = None self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False self.allow_custom_ops = False @classmethod - def from_session(cls, - sess, - input_tensors, - output_tensors, - freeze_variables=False): + def from_session(cls, sess, input_tensors, output_tensors): """Creates a TocoConverter class from a TensorFlow Session. Args:
@@ -112,56 +143,102 @@ class TocoConverter(object): input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - freeze_variables: Boolean indicating whether the variables need to be - converted into constants via the freeze_graph.py script. - (default False) Returns: TocoConverter class. """ + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + @classmethod + def from_frozen_graph(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a frozen GraphDef. + + Args: + graph_def_file: Full filepath of file containing TensorFlow GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) - # Get GraphDef. - if freeze_variables: + Returns: + TocoConverter class. + + Raises: + ValueError: + Unable to parse input file. + The graph is not frozen. + input_arrays or output_arrays contains an invalid tensor name. + """ + with _session.Session() as sess: sess.run(global_variables_initializer()) - output_arrays = [tensor_name(tensor) for tensor in output_tensors] - graph_def = tf_graph_util.convert_variables_to_constants( - sess, sess.graph_def, output_arrays) - else: - graph_def = sess.graph_def - # Create TocoConverter class. - return cls(graph_def, input_tensors, output_tensors) + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() + try: + graph_def.ParseFromString(file_content) + except (_text_format.ParseError, DecodeError): + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + sess.graph.as_default() + import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + set_tensor_shapes(input_tensors, input_shapes) + + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors)
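Given a frozen GraphDef on disk, conversion is then a short sequence of calls; a sketch with assumed file paths and array names:

    from tensorflow.contrib.lite.python import lite

    converter = lite.TocoConverter.from_frozen_graph(
        "/tmp/model.pb", input_arrays=["img"], output_arrays=["out"],
        input_shapes={"img": [1, 16, 16, 3]})
    tflite_model = converter.convert()
    open("/tmp/model.tflite", "wb").write(tflite_model)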
@classmethod - def from_saved_model( - cls, - saved_model_dir, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY): + def from_saved_model(cls, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): """Creates a TocoConverter class from a SavedModel. Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). (default None) output_arrays: List of output tensors to freeze graph with. Uses output arrays from SignatureDef when none are provided. (default None) tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. (default set("serve")) signature_key: Key identifying SignatureDef containing inputs and outputs. + (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) Returns: TocoConverter class. """ if tag_set is None: tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, signature_key)
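Because the tag set and signature key now default to None and fall back to tag_constants.SERVING and DEFAULT_SERVING_SIGNATURE_DEF_KEY inside the method, the common case reduces to a one-argument call; the directory below is a placeholder:

    from tensorflow.contrib.lite.python import lite

    converter = lite.TocoConverter.from_saved_model("/tmp/saved_model")
    tflite_model = converter.convert()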
@@ -189,16 +266,39 @@ class TocoConverter(object): elif shape[0] is None: self._set_batch_size(batch_size=1) + # Get quantization stats. Ensures there is one stat per name if the stats + # are specified. + if self.quantized_input_stats: + quantized_stats = [] + invalid_stats = [] + for tensor in self._input_tensors: + name = tensor_name(tensor) + if name in self.quantized_input_stats: + quantized_stats.append(self.quantized_input_stats[name]) + else: + invalid_stats.append(name) + + if invalid_stats: + raise ValueError("Quantization input stats are not available for input " + "tensors '{0}'.".format(",".join(invalid_stats))) + else: + quantized_stats = None + # Converts model. result = toco_convert( input_data=self._graph_def, input_tensors=self._input_tensors, output_tensors=self._output_tensors, inference_type=self.inference_type, + inference_input_type=self.inference_input_type, input_format=constants.TENSORFLOW_GRAPHDEF, output_format=self.output_format, - quantized_input_stats=self.quantized_input_stats, - drop_control_dependency=self.drop_control_dependency) + quantized_input_stats=quantized_stats, + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops) return result def _set_batch_size(self, batch_size):
@@ -212,3 +312,43 @@ class TocoConverter(object): shape = tensor.get_shape().as_list() shape[0] = batch_size tensor.set_shape(shape) + + +def _is_frozen_graph(sess): + """Determines if the graph is frozen. + + Determines if a graph has previously been frozen by checking for any + operations of type Variable*. If variables are found, the graph is not frozen. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable"): + return False + return True + + +def _freeze_graph(sess, output_tensors): + """Returns a frozen GraphDef. + + Freezes a graph with Variables in it. Otherwise the existing GraphDef is + returned. + + Args: + sess: TensorFlow Session. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + Frozen GraphDef. + """ + if not _is_frozen_graph(sess): + sess.run(global_variables_initializer()) + output_arrays = [tensor_name(tensor) for tensor in output_tensors] + return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, + output_arrays) + else: + return sess.graph_def
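Putting the new attributes together, a quantized conversion from a live session might look like the following sketch (a toy graph; the stats dict is keyed by tensor name without the output index, as in the tests below):

    import tensorflow as tf
    from tensorflow.contrib.lite.python import lite
    from tensorflow.contrib.lite.python import lite_constants

    inp = tf.placeholder(name="inp", dtype=tf.float32, shape=(1, 16, 16, 3))
    out = tf.fake_quant_with_min_max_args(inp + inp, min=0., max=1.)
    sess = tf.Session()

    converter = lite.TocoConverter.from_session(sess, [inp], [out])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {"inp": (0., 1.)}  # mean, std_dev
    converter.default_ranges_stats = (0, 6)  # fallback (min, max) range
    tflite_model = converter.convert()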
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f3105f3e6..53d1878293 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph class FromSessionTest(test_util.TensorFlowTestCase):
@@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual((0., 0.), output_details[0]['quantization']) def testQuantization(self): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') out_tensor = array_ops.fake_quant_with_min_max_args( - in_tensor + in_tensor, min=0., max=1., name='output') + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) converter.inference_type = lite_constants.QUANTIZED_UINT8 - converter.quantized_input_stats = [(0., 1.)] # mean, std_dev + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev tflite_model = converter.convert() self.assertTrue(tflite_model)
@@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter.allocate_tensors() input_details = interpreter.get_input_details() - self.assertEqual(1, len(input_details)) - self.assertEqual('input', input_details[0]['name']) + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((1., 0.), input_details[0]['quantization']) # scale, zero_point + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('output', output_details[0]['name'])
@@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + def testBatchSizeInvalid(self): in_tensor = array_ops.placeholder( shape=[None, 16, 16, 3], dtype=dtypes.float32)
@@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session( - sess, [in_tensor], [out_tensor], freeze_variables=True) + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) tflite_model = converter.convert() self.assertTrue(tflite_model)
@@ -187,6 +220,196 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + def testInferenceInputType(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
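The tests exercise converted models through the Interpreter class; running a real inference follows the same pattern. A sketch with dummy input, assuming a float model produced by one of the conversions above:

    import numpy as np
    from tensorflow.contrib.lite.python.interpreter import Interpreter

    # `tflite_model` is the bytes object returned by converter.convert().
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Feed zeros of the expected shape, run, and read the result back.
    interpreter.set_tensor(input_details[0]['index'],
                           np.zeros(input_details[0]['shape'], np.float32))
    interpreter.invoke()
    result = interpreter.get_tensor(output_details[0]['index'])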
+ + +class FromFlatbufferFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + class FromSavedModelTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
new file mode 100644
index 0000000000..337f05785e
--- /dev/null
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -0,0 +1,324 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python command line interface for running TOCO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.platform import app + + +def _parse_array(values): + if values: + return values.split(",") + + +def _parse_int_array(values): + if values: + return [int(val) for val in values.split(",")] + + +def _parse_set(values): + if values: + return set(values.split(",")) + + +def _get_toco_converter(flags): + """Makes a TocoConverter object based on the flags provided. + + Args: + flags: argparse.Namespace object containing TFLite flags. + + Returns: + TocoConverter object. + """ + # Parse input and output arrays. + input_arrays = _parse_array(flags.input_arrays) + input_shapes = None + if flags.input_shapes: + input_shapes_list = [ + _parse_int_array(shape) for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(zip(input_arrays, input_shapes_list)) + output_arrays = _parse_array(flags.output_arrays) + + converter_kwargs = { + "input_arrays": input_arrays, + "input_shapes": input_shapes, + "output_arrays": output_arrays + } + + # Create TocoConverter. + if flags.graph_def_file: + converter_fn = lite.TocoConverter.from_frozen_graph + converter_kwargs["graph_def_file"] = flags.graph_def_file + elif flags.saved_model_dir: + converter_fn = lite.TocoConverter.from_saved_model + converter_kwargs["saved_model_dir"] = flags.saved_model_dir + converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) + converter_kwargs["signature_key"] = flags.saved_model_signature_key + + return converter_fn(**converter_kwargs) + + +def _convert_model(flags): + """Calls function to convert the TensorFlow model into a TFLite model. + + Args: + flags: argparse.Namespace object. + """ + # Create converter. + converter = _get_toco_converter(flags) + if flags.inference_type: + converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) + if flags.output_format: + converter.output_format = _toco_flags_pb2.FileFormat.Value( + flags.output_format) + + if flags.mean_values and flags.std_dev_values: + input_arrays = _parse_array(flags.input_arrays) + std_dev_values = _parse_int_array(flags.std_dev_values) + mean_values = _parse_int_array(flags.mean_values) + quant_stats = zip(mean_values, std_dev_values) + converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if flags.default_ranges_min and flags.default_ranges_max: + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) + + if flags.drop_control_dependency: + converter.drop_control_dependency = flags.drop_control_dependency + if flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges + if flags.allow_custom_ops: + converter.allow_custom_ops = flags.allow_custom_ops + + # Convert model. + output_data = converter.convert() + with open(flags.output_file, "wb") as f: + f.write(output_data) + + +def _check_flags(flags, unparsed): + """Checks the parsed and unparsed flags to ensure they are valid. + + Raises an error if previously supported unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. + + Args: + flags: argparse.Namespace object containing TFLite flags. + unparsed: List of unparsed flags. + + Raises: + ValueError: Invalid flags. + """ + + # Check unparsed flags for common mistakes based on previous TOCO. + def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + + if unparsed: + output = "" + for flag in unparsed: + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + raise ValueError(output) + + # Check that flags are valid.
+ if flags.graph_def_file and (not flags.input_arrays or + not flags.output_arrays): + raise ValueError("--input_arrays and --output_arrays are required with " + "--graph_def_file") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + if flags.std_dev_values or flags.mean_values: + if bool(flags.std_dev_values) != bool(flags.mean_values): + raise ValueError("--std_dev_values and --mean_values must be used " + "together") + if not flags.input_arrays: + raise ValueError("--std_dev_values and --mean_values must be used with " + "--input_arrays") + if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or + flags.std_dev_values.count(",") != flags.input_arrays.count(",")): + raise ValueError("--std_dev_values, --mean_values, and --input_arrays " + "must have the same number of items") + + if bool(flags.default_ranges_min) != bool(flags.default_ranges_max): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + + +def run_main(_): + """Main in tflite_convert.py.""" + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Optimizing " + "Converter (TOCO).")) + + # Output file flag. + parser.add_argument( + "--output_file", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input file flags. + input_file_group = parser.add_mutually_exclusive_group(required=True) + input_file_group.add_argument( + "--graph_def_file", + type=str, + help="Full filepath of file containing TensorFlow GraphDef.") + input_file_group.add_argument( + "--saved_model_dir", + type=str, + help="Full filepath of directory containing the SavedModel.") + + # Model format flags. + parser.add_argument( + "--output_format", + type=str, + choices=["TFLITE", "GRAPHVIZ_DOT"], + help="Output file format.") + parser.add_argument( + "--inference_type", + type=str, + choices=["FLOAT", "QUANTIZED_UINT8"], + help="Target data type of arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str, + choices=["FLOAT", "QUANTIZED_UINT8"], + help=("Target data type of input arrays. Allows for a different type for " + "input arrays in the case of quantization.")) + + # Input and output arrays flags. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the input arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + + # SavedModel related flags. + parser.add_argument( + "--saved_model_tag_set", + type=str, + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) + parser.add_argument( + "--saved_model_signature_key", + type=str, + help=("Key identifying the SignatureDef containing inputs and outputs. " + "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) + + # Quantization flags. + parser.add_argument( + "--std_dev_values", + type=str, + help=("Standard deviation of training data for each input tensor, " + "comma-separated. Used for quantization. (default None)"))
(default None)")) + parser.add_argument( + "--mean_values", + type=str, + help=("Mean of training data for each input tensor, comma-separated. " + "Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + + # Graph manipulation flags. + parser.add_argument( + "--drop_control_dependency", + type=bool, + help=("Boolean indicating whether to drop control dependencies silently. " + "This is due to TensorFlow not supporting control dependencies. " + "(default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + type=bool, + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior. (default False)")) + parser.add_argument( + "--change_concat_input_ranges", + type=bool, + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) + parser.add_argument( + "--allow_custom_ops", + type=bool, + help=("Boolean indicating whether to allow custom operations. When false " + "any unknown operation is an error. When true, custom ops are " + "created for any op that is unknown. The developer will need to " + "provide these to the TensorFlow Lite runtime with a custom " + "resolver. (default False)")) + + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) + try: + _check_flags(tflite_flags, unparsed) + except ValueError as e: + parser.print_usage() + file_name = os.path.basename(sys.argv[0]) + sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) + sys.exit(1) + _convert_model(tflite_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() |