author Michael Case <mikecase@google.com> 2018-05-31 18:29:32 -0700
committer Michael Case <mikecase@google.com> 2018-05-31 18:29:32 -0700
commit 7a9f1e6ec8ae53cd4321e819bd2343429a5ea9eb (patch)
tree 430e7ffc7d8157f2e0b4bb6d814e2eb0d375d31b /tensorflow/contrib/lite/python
parent e365deab1333005c8aa186632f160c1bfd4485f8 (diff)
parent 2e272dbca6600991599e55a7ff7cfa668b8403aa (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
 tensorflow/contrib/lite/python/BUILD                       |  19
 tensorflow/contrib/lite/python/convert.py                  |  63
 tensorflow/contrib/lite/python/convert_saved_model.py      | 118
 tensorflow/contrib/lite/python/convert_saved_model_test.py |  55
 tensorflow/contrib/lite/python/lite.py                     | 218
 tensorflow/contrib/lite/python/lite_test.py                | 241
 tensorflow/contrib/lite/python/tflite_convert.py           | 324
 7 files changed, 909 insertions(+), 129 deletions(-)
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index a40e512045..7e6ff6c0a8 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -36,6 +36,16 @@ py_test(
],
)
+py_binary(
+ name = "tflite_convert",
+ srcs = ["tflite_convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite",
+ ],
+)
+
py_library(
name = "lite",
srcs = ["lite.py"],
@@ -125,6 +135,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
"//tensorflow/python:platform",
@@ -164,11 +175,3 @@ py_test(
"//tensorflow/python/saved_model",
],
)
-
-# Transitive dependencies of this target will be included in the pip package.
-py_library(
- name = "tf_lite_py_pip",
- deps = [
- ":convert_saved_model",
- ],
-)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index c0926d2f33..0819475240 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -115,11 +115,15 @@ def toco_convert(input_data,
input_tensors,
output_tensors,
inference_type=lite_constants.FLOAT,
+ inference_input_type=None,
input_format=lite_constants.TENSORFLOW_GRAPHDEF,
output_format=lite_constants.TFLITE,
quantized_input_stats=None,
+ default_ranges_stats=None,
drop_control_dependency=True,
- allow_custom_ops=False):
+ reorder_across_fake_quant=False,
+ allow_custom_ops=False,
+ change_concat_input_ranges=False):
"""Convert a model using TOCO from `input_format` to `output_format`.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -130,18 +134,41 @@ def toco_convert(input_data,
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
+ inference_type: Target data type of arrays in the output file. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of input arrays. Allows for a
+ different type for input arrays in the case of quantization. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ input_format: Type of data to read. Currently must be
+ `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
+ output_format: Output file format. Currently must be `{TFLITE,
+ GRAPHVIZ_DOT}`. (default TFLITE)
+ quantized_input_stats: List of tuples of floats representing the mean and
+ standard deviation of the training data, one tuple per member of
+ input_tensors (e.g., [(0., 1.)]). Only needed if `inference_type` is
+ `QUANTIZED_UINT8`. (default None)
+ default_ranges_stats: Tuple of integers representing (min, max) range values
+ for all arrays without a specified range. Intended for experimenting with
+ quantization via "dummy quantization". (default None)
+ drop_control_dependency: Boolean indicating whether to drop control
+ dependencies silently. This is due to TFLite not supporting control
+ dependencies. (default True)
+ reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+ nodes in unexpected locations. Used when the location of the FakeQuant
+ nodes is preventing graph transformations necessary to convert the graph.
+ Results in a graph that differs from the quantized training graph,
+ potentially causing differing arithmetic behavior. (default False)
+ allow_custom_ops: Boolean indicating whether to allow custom operations.
+ When false any unknown operation is an error. When true, custom ops are
+ created for any op that is unknown. The developer will need to provide
+ these to the TensorFlow Lite runtime with a custom resolver.
+ (default False)
+ change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+ inputs and outputs of the concat operator for quantized models. Changes
+ the ranges of concat operator overlap when true. (default False)
Returns:
- The converted data. For example if tflite was the destination, then
+ The converted data. For example if TFLite was the destination, then
this will be a tflite flatbuffer in a bytes array.
Raises:
@@ -152,10 +179,18 @@ def toco_convert(input_data,
toco = _toco_flags_pb2.TocoFlags()
toco.input_format = input_format
toco.output_format = output_format
- toco.drop_control_dependency = drop_control_dependency
- model = _model_flags_pb2.ModelFlags()
toco.inference_type = inference_type
+ if inference_input_type:
+ toco.inference_input_type = inference_input_type
+ toco.drop_control_dependency = drop_control_dependency
+ toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
+ if default_ranges_stats:
+ toco.default_ranges_min = default_ranges_stats[0]
+ toco.default_ranges_max = default_ranges_stats[1]
+
+ model = _model_flags_pb2.ModelFlags()
+ model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
if input_tensor.dtype == _dtypes.float32:
tflite_input_type = lite_constants.FLOAT
@@ -163,6 +198,8 @@ def toco_convert(input_data,
tflite_input_type = lite_constants.INT32
elif input_tensor.dtype == _dtypes.int64:
tflite_input_type = lite_constants.INT64
+ elif input_tensor.dtype == _dtypes.uint8:
+ tflite_input_type = lite_constants.QUANTIZED_UINT8
# TODO(aselle): Insert strings when they are available
else:
raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 54fec9d61f..b952a72aab 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -18,31 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-
-
-def _write_and_flush_file(file_path, data_str):
- """Writes data to file path.
-
- Args:
- file_path: Full path of the file to store data in.
- data_str: Data represented as a string.
-
- Returns: None.
- """
- with gfile.Open(file_path, "wb") as data_file:
- data_file.write(data_str)
- data_file.flush()
def _log_tensor_details(tensor_info):
@@ -167,29 +151,10 @@ def _get_tensors(graph, signature_def_tensor_names=None,
"""
tensors = []
if user_tensor_names:
- # Get the list of all of the tensors with and without the tensor index.
- all_tensor_names = [
- tensor.name for op in graph.get_operations() for tensor in op.outputs
- ]
- all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
-
# Sort the tensor names.
user_tensor_names = sorted(user_tensor_names)
- # Get the tensors associated with the tensor names.
- tensors = []
- invalid_tensors = []
- for name in user_tensor_names:
- if name not in all_tensor_names_only:
- invalid_tensors.append(name)
- else:
- idx = all_tensor_names_only.index(name)
- tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
-
- # Throw ValueError if any user input names are not valid tensors.
- if invalid_tensors:
- raise ValueError("Invalid tensors '{}' were found.".format(
- ",".join(invalid_tensors)))
+ tensors = get_tensors_from_tensor_names(graph, user_tensor_names)
elif signature_def_tensor_names:
tensors = [
graph.get_tensor_by_name(name)
@@ -204,6 +169,58 @@ def _get_tensors(graph, signature_def_tensor_names=None,
return tensors
+def get_tensors_from_tensor_names(graph, tensor_names):
+ """Gets the Tensors associated with the `tensor_names` in the provided graph.
+
+ Args:
+ graph: TensorFlow Graph.
+ tensor_names: List of strings that represent names of tensors in the graph.
+
+ Returns:
+ A list of Tensor objects in the same order the names are provided.
+
+ Raises:
+ ValueError:
+ tensor_names contains an invalid tensor name.
+ """
+ # Get the list of all of the tensors.
+ tensor_name_to_tensor = {
+ tensor_name(tensor): tensor for op in graph.get_operations()
+ for tensor in op.values()
+ }
+
+ # Get the tensors associated with tensor_names.
+ tensors = []
+ invalid_tensors = []
+ for name in tensor_names:
+ tensor = tensor_name_to_tensor.get(name)
+ if tensor is None:
+ invalid_tensors.append(name)
+ else:
+ tensors.append(tensor)
+
+ # Throw ValueError if any user input names are not valid tensors.
+ if invalid_tensors:
+ raise ValueError("Invalid tensors '{}' were found.".format(
+ ",".join(invalid_tensors)))
+ return tensors
+
+
+def set_tensor_shapes(tensors, shapes):
+ """Sets Tensor shape for each tensor if the shape is defined.
+
+ Args:
+ tensors: List of TensorFlow ops.Tensor objects.
+ shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ """
+ if shapes:
+ for tensor in tensors:
+ shape = shapes.get(tensor.name)
+ if shape is not None:
+ tensor.set_shape(shape)
+
+
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
@@ -211,15 +228,14 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
+ from SignatureDef when none are provided.
+ input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
+ arrays from SignatureDef when none are provided.
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present.
signature_key: Key identifying SignatureDef containing inputs and outputs.
Returns:
@@ -233,14 +249,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_key is not in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
- Unable to load Session.
"""
- # Set default values for inputs if they are set to None.
- if signature_key is None:
- signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- if tag_set is None:
- tag_set = set([tag_constants.SERVING])
-
# Read SignatureDef.
meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
signature_def = _get_signature_def(meta_graph, signature_key)
@@ -255,19 +264,10 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
# TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
-
- # Gets fully defined tensor shape.
- for tensor in in_tensors:
- if (input_shapes and tensor.name in input_shapes and
- input_shapes[tensor.name] is not None):
- shape = input_shapes[tensor.name]
- else:
- shape = tensor.get_shape().as_list()
- tensor.set_shape(shape)
+ set_tensor_shapes(in_tensors, input_shapes)
output_names = [node.split(":")[0] for node in outputs]
frozen_graph_def = tf_graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), output_names)
return frozen_graph_def, in_tensors, out_tensors
- raise ValueError("Unable to load Session.")
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index f69381d0e6..80e5dc6e46 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -41,9 +41,58 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import training as train
+class TensorFunctionsTest(test_util.TensorFlowTestCase):
+
+ def testGetTensorsValid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tensors = convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, ["Placeholder"])
+ self.assertEqual("Placeholder:0", tensors[0].name)
+
+ def testGetTensorsInvalid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ with self.assertRaises(ValueError) as error:
+ convert_saved_model.get_tensors_from_tensor_names(sess.graph,
+ ["invalid-input"])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ def testSetTensorShapeValid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"Placeholder:0": [5, 3, 5]})
+ self.assertEqual([5, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeInvalid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"invalid-input": [5, 3, 5]})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeEmpty(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor], {})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+
class FreezeSavedModelTest(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
@@ -93,6 +142,10 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase):
output_arrays=None,
tag_set=None,
signature_key=None):
+ if tag_set is None:
+ tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
saved_model_dir=saved_model_dir,
input_arrays=input_arrays,
@@ -390,7 +443,7 @@ class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
input_arrays=None,
input_shapes=None,
output_arrays=["Softmax"],
- tag_set=None,
+ tag_set=set([tag_constants.SERVING]),
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index f7f2d40a02..253b5eadf3 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -33,15 +33,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from google.protobuf import text_format as _text_format
+from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
+from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names
+from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes
from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -55,26 +62,50 @@ class TocoConverter(object):
Attributes:
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- (default FLOAT)
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT). (default TFLITE)
- quantized_input_stats: The mean and std deviation of training data for each
- input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
- (default None)
+ inference_type: Target data type of arrays in the output file. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of input arrays. Allows for a
+ different type for input arrays in the case of quantization. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ output_format: Output file format. Currently must be `{TFLITE,
+ GRAPHVIZ_DOT}`. (default TFLITE)
+ quantized_input_stats: Dict of strings representing input tensor names
+ mapped to tuple of floats representing the mean and standard deviation
+ of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
+ `inference_type` is `QUANTIZED_UINT8`. (default {})
+ default_ranges_stats: Tuple of integers representing (min, max) range values
+ for all arrays without a specified range. Intended for experimenting with
+ quantization via "dummy quantization". (default None)
drop_control_dependency: Boolean indicating whether to drop control
dependencies silently. This is due to TFLite not supporting control
dependencies. (default True)
+ reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+ nodes in unexpected locations. Used when the location of the FakeQuant
+ nodes is preventing graph transformations necessary to convert the graph.
+ Results in a graph that differs from the quantized training graph,
+ potentially causing differing arithmetic behavior. (default False)
+ change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+ inputs and outputs of the concat operator for quantized models. Changes
+ the ranges of concat operator overlap when true. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations.
+ When false any unknown operation is an error. When true, custom ops are
+ created for any op that is unknown. The developer will need to provide
+ these to the TensorFlow Lite runtime with a custom resolver.
(default False)
Example usage:
- # Converting a frozen graph.
+ # Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
+ # Converting a GraphDef from file.
+ converter = lite.TocoConverter.from_frozen_graph(
+ graph_def_file, input_arrays, output_arrays)
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
+
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
@@ -94,17 +125,17 @@ class TocoConverter(object):
self._input_tensors = input_tensors
self._output_tensors = output_tensors
self.inference_type = constants.FLOAT
+ self.inference_input_type = None
self.output_format = constants.TFLITE
- self.quantized_input_stats = None
+ self.quantized_input_stats = {}
+ self.default_ranges_stats = None
self.drop_control_dependency = True
+ self.reorder_across_fake_quant = False
+ self.change_concat_input_ranges = False
self.allow_custom_ops = False
@classmethod
- def from_session(cls,
- sess,
- input_tensors,
- output_tensors,
- freeze_variables=False):
+ def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
Args:
@@ -112,56 +143,102 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- freeze_variables: Boolean indicating whether the variables need to be
- converted into constants via the freeze_graph.py script.
- (default False)
Returns:
TocoConverter class.
"""
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
+
+ @classmethod
+ def from_frozen_graph(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a frozen GraphDef.
+
+ Args:
+ graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
- # Get GraphDef.
- if freeze_variables:
+ Returns:
+ TocoConverter class.
+
+ Raises:
+ ValueError:
+ Unable to parse input file.
+ The graph is not frozen.
+ input_arrays or output_arrays contains an invalid tensor name.
+ """
+ with _session.Session() as sess:
sess.run(global_variables_initializer())
- output_arrays = [tensor_name(tensor) for tensor in output_tensors]
- graph_def = tf_graph_util.convert_variables_to_constants(
- sess, sess.graph_def, output_arrays)
- else:
- graph_def = sess.graph_def
- # Create TocoConverter class.
- return cls(graph_def, input_tensors, output_tensors)
+ # Read GraphDef from file.
+ graph_def = _graph_pb2.GraphDef()
+ with open(graph_def_file, "rb") as f:
+ file_content = f.read()
+ try:
+ graph_def.ParseFromString(file_content)
+ except (_text_format.ParseError, DecodeError):
+ try:
+ print("Ignore 'tcmalloc: large alloc' warnings.")
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise ValueError(
+ "Unable to parse input file '{}'.".format(graph_def_file))
+ sess.graph.as_default()
+ import_graph_def(graph_def, name="")
+
+ # Get input and output tensors.
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py")
+
+ # Create TocoConverter class.
+ return cls(sess.graph_def, input_tensors, output_tensors)
@classmethod
- def from_saved_model(
- cls,
- saved_model_dir,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
"""Creates a TocoConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input
arrays from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present. (default set("serve"))
signature_key: Key identifying SignatureDef containing inputs and outputs.
+ (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
TocoConverter class.
"""
if tag_set is None:
tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key)
@@ -189,16 +266,39 @@ class TocoConverter(object):
elif shape[0] is None:
self._set_batch_size(batch_size=1)
+ # Get quantization stats. Ensures there is one stat per name if the stats
+ # are specified.
+ if self.quantized_input_stats:
+ quantized_stats = []
+ invalid_stats = []
+ for tensor in self._input_tensors:
+ name = tensor_name(tensor)
+ if name in self.quantized_input_stats:
+ quantized_stats.append(self.quantized_input_stats[name])
+ else:
+ invalid_stats.append(name)
+
+ if invalid_stats:
+ raise ValueError("Quantization input stats are not available for input "
+ "tensors '{0}'.".format(",".join(invalid_stats)))
+ else:
+ quantized_stats = None
+
# Converts model.
result = toco_convert(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
inference_type=self.inference_type,
+ inference_input_type=self.inference_input_type,
input_format=constants.TENSORFLOW_GRAPHDEF,
output_format=self.output_format,
- quantized_input_stats=self.quantized_input_stats,
- drop_control_dependency=self.drop_control_dependency)
+ quantized_input_stats=quantized_stats,
+ default_ranges_stats=self.default_ranges_stats,
+ drop_control_dependency=self.drop_control_dependency,
+ reorder_across_fake_quant=self.reorder_across_fake_quant,
+ change_concat_input_ranges=self.change_concat_input_ranges,
+ allow_custom_ops=self.allow_custom_ops)
return result
def _set_batch_size(self, batch_size):
@@ -212,3 +312,43 @@ class TocoConverter(object):
shape = tensor.get_shape().as_list()
shape[0] = batch_size
tensor.set_shape(shape)
+
+
+def _is_frozen_graph(sess):
+ """Determines if the graph is frozen.
+
+ Determines if a graph has previously been frozen by checking for any
+ operations of type Variable*. If variables are found, the graph is not frozen.
+
+ Args:
+ sess: TensorFlow Session.
+
+ Returns:
+ Bool.
+ """
+ for op in sess.graph.get_operations():
+ if op.type.startswith("Variable"):
+ return False
+ return True
+
+
+def _freeze_graph(sess, output_tensors):
+ """Returns a frozen GraphDef.
+
+ Freezes a graph with Variables in it. Otherwise the existing GraphDef is
+ returned.
+
+ Args:
+ sess: TensorFlow Session.
+ output_tensors: List of output tensors (only .name is used from this).
+
+ Returns:
+ Frozen GraphDef.
+ """
+ if not _is_frozen_graph(sess):
+ sess.run(global_variables_initializer())
+ output_arrays = [tensor_name(tensor) for tensor in output_tensors]
+ return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def,
+ output_arrays)
+ else:
+ return sess.graph_def
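Putting the reworked lite.py surface together, a hedged end-to-end sketch; the file path and array names are placeholders, not from this commit:

    from tensorflow.contrib.lite.python import lite
    from tensorflow.contrib.lite.python import lite_constants

    converter = lite.TocoConverter.from_frozen_graph(
        "/tmp/frozen_model.pb", ["input"], ["output"],
        input_shapes={"input": [1, 16, 16, 3]})
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {"input": (0., 1.)}  # keyed by name now
    converter.default_ranges_stats = (0, 6)  # "dummy quantization" fallback
    tflite_model = converter.convert()
    open("/tmp/model.tflite", "wb").write(tflite_model)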
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f3105f3e6..53d1878293 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
class FromSessionTest(test_util.TensorFlowTestCase):
@@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testQuantization(self):
- in_tensor = array_ops.placeholder(
- shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
out_tensor = array_ops.fake_quant_with_min_max_args(
- in_tensor + in_tensor, min=0., max=1., name='output')
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {
+ 'inputA': (0., 1.),
+ 'inputB': (0., 1.)
+ } # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.uint8, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((1., 0.),
+ input_details[1]['quantization']) # scale, zero_point
+
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
@@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizationInvalid(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ out_tensor = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual(
+ 'Quantization input stats are not available for input tensors '
+ '\'inputB\'.', str(error.exception))
+
def testBatchSizeInvalid(self):
in_tensor = array_ops.placeholder(
shape=[None, 16, 16, 3], dtype=dtypes.float32)
@@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
- sess, [in_tensor], [out_tensor], freeze_variables=True)
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -187,6 +220,196 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ def testInferenceInputType(self):
+ in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testDefaultRangesStats(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
+ converter.default_ranges_stats = (0, 6) # min, max
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+
+
+class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFloatWithShapesArray(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(
+ graph_def_file, ['Placeholder'], ['add'],
+ input_shapes={'Placeholder': [1, 16, 16, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+
+ def testFreezeGraph(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ var = variable_scope.get_variable(
+ 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + var
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Ensure the graph with variables cannot be converted.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual('Please freeze the graph using freeze_graph.py',
+ str(error.exception))
+
+ def testPbtxt(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
+ write_graph(sess.graph_def, '', graph_def_file, True)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testInvalidFile(self):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
+ with gfile.Open(graph_def_file, 'wb') as temp_file:
+ temp_file.write('bad data')
+ temp_file.flush()
+
+ # Attempts to convert the invalid model.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual(
+ 'Unable to parse input file \'{}\'.'.format(graph_def_file),
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
new file mode 100644
index 0000000000..337f05785e
--- /dev/null
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -0,0 +1,324 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python command line interface for running TOCO."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.platform import app
+
+
+def _parse_array(values):
+ if values:
+ return values.split(",")
+
+
+def _parse_int_array(values):
+ if values:
+ return [int(val) for val in values.split(",")]
+
+
+def _parse_set(values):
+ if values:
+ return set(values.split(","))
+
+
+def _get_toco_converter(flags):
+ """Makes a TocoConverter object based on the flags provided.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+
+ Returns:
+ TocoConverter object.
+ """
+ # Parse input and output arrays.
+ input_arrays = _parse_array(flags.input_arrays)
+ input_shapes = None
+ if flags.input_shapes:
+ input_shapes_list = [
+ _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+ ]
+ input_shapes = dict(zip(input_arrays, input_shapes_list))
+ output_arrays = _parse_array(flags.output_arrays)
+
+ converter_kwargs = {
+ "input_arrays": input_arrays,
+ "input_shapes": input_shapes,
+ "output_arrays": output_arrays
+ }
+
+ # Create TocoConverter.
+ if flags.graph_def_file:
+ converter_fn = lite.TocoConverter.from_frozen_graph
+ converter_kwargs["graph_def_file"] = flags.graph_def_file
+ elif flags.saved_model_dir:
+ converter_fn = lite.TocoConverter.from_saved_model
+ converter_kwargs["saved_model_dir"] = flags.saved_model_dir
+ converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
+ converter_kwargs["signature_key"] = flags.saved_model_signature_key
+
+ return converter_fn(**converter_kwargs)
+
+
+def _convert_model(flags):
+ """Calls function to convert the TensorFlow model into a TFLite model.
+
+ Args:
+ flags: argparse.Namespace object.
+ """
+ # Create converter.
+ converter = _get_toco_converter(flags)
+ if flags.inference_type:
+ converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type)
+ if flags.inference_input_type:
+ converter.inference_input_type = _types_pb2.IODataType.Value(
+ flags.inference_input_type)
+ if flags.output_format:
+ converter.output_format = _toco_flags_pb2.FileFormat.Value(
+ flags.output_format)
+
+ if flags.mean_values and flags.std_dev_values:
+ input_arrays = _parse_array(flags.input_arrays)
+ std_dev_values = _parse_int_array(flags.std_dev_values)
+ mean_values = _parse_int_array(flags.mean_values)
+ quant_stats = zip(mean_values, std_dev_values)
+ converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
+ if flags.default_ranges_min and flags.default_ranges_max:
+ converter.default_ranges_stats = (flags.default_ranges_min,
+ flags.default_ranges_max)
+
+ if flags.drop_control_dependency:
+ converter.drop_control_dependency = flags.drop_control_dependency
+ if flags.reorder_across_fake_quant:
+ converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
+ if flags.change_concat_input_ranges:
+ converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ if flags.allow_custom_ops:
+ converter.allow_custom_ops = flags.allow_custom_ops
+
+ # Convert model.
+ output_data = converter.convert()
+ with open(flags.output_file, "wb") as f:
+ f.write(output_data)
+
+
+def _check_flags(flags, unparsed):
+ """Checks the parsed and unparsed flags to ensure they are valid.
+
+ Raises an error if previously supported unparsed flags are found. Raises an
+ error for parsed flags that don't meet the required conditions.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+ unparsed: List of unparsed flags.
+
+ Raises:
+ ValueError: Invalid flags.
+ """
+
+ # Check unparsed flags for common mistakes based on previous TOCO.
+ def _get_message_unparsed(flag, orig_flag, new_flag):
+ if flag.startswith(orig_flag):
+ return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
+ return ""
+
+ if unparsed:
+ output = ""
+ for flag in unparsed:
+ output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
+ output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
+ output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
+ raise ValueError(output)
+
+ # Check that flags are valid.
+ if flags.graph_def_file and (not flags.input_arrays or
+ not flags.output_arrays):
+ raise ValueError("--input_arrays and --output_arrays are required with "
+ "--graph_def_file")
+
+ if flags.input_shapes:
+ if not flags.input_arrays:
+ raise ValueError("--input_shapes must be used with --input_arrays")
+ if flags.input_shapes.count(":") != flags.input_arrays.count(","):
+ raise ValueError("--input_shapes and --input_arrays must have the same "
+ "number of items")
+
+ if flags.std_dev_values or flags.mean_values:
+ if bool(flags.std_dev_values) != bool(flags.mean_values):
+ raise ValueError("--std_dev_values and --mean_values must be used "
+ "together")
+ if not flags.input_arrays:
+ raise ValueError("--std_dev_values and --mean_values must be used with "
+ "--input_arrays")
+ if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or
+ flags.std_dev_values.count(",") != flags.input_arrays.count(",")):
+ raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
+ "must have the same number of items")
+
+ if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
+ raise ValueError("--default_ranges_min and --default_ranges_max must be "
+ "used together")
+
+
+def run_main(_):
+ """Main in toco_convert.py."""
+ parser = argparse.ArgumentParser(
+ description=("Command line tool to run TensorFlow Lite Optimizing "
+ "Converter (TOCO)."))
+
+ # Output file flag.
+ parser.add_argument(
+ "--output_file",
+ type=str,
+ help="Full filepath of the output file.",
+ required=True)
+
+ # Input file flags.
+ input_file_group = parser.add_mutually_exclusive_group(required=True)
+ input_file_group.add_argument(
+ "--graph_def_file",
+ type=str,
+ help="Full filepath of file containing TensorFlow GraphDef.")
+ input_file_group.add_argument(
+ "--saved_model_dir",
+ type=str,
+ help="Full filepath of directory containing the SavedModel.")
+
+ # Model format flags.
+ parser.add_argument(
+ "--output_format",
+ type=str,
+ choices=["TFLITE", "GRAPHVIZ_DOT"],
+ help="Output file format.")
+ parser.add_argument(
+ "--inference_type",
+ type=str,
+ choices=["FLOAT", "QUANTIZED_UINT8"],
+ help="Target data type of arrays in the output file.")
+ parser.add_argument(
+ "--inference_input_type",
+ type=str,
+ choices=["FLOAT", "QUANTIZED_UINT8"],
+ help=("Target data type of input arrays. Allows for a different type for "
+ "input arrays in the case of quantization."))
+
+ # Input and output arrays flags.
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.")
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.")
+
+ # SavedModel related flags.
+ parser.add_argument(
+ "--saved_model_tag_set",
+ type=str,
+ help=("Comma-separated set of tags identifying the MetaGraphDef within "
+ "the SavedModel to analyze. All tags must be present. "
+ "(default \"serve\")"))
+ parser.add_argument(
+ "--saved_model_signature_key",
+ type=str,
+ help=("Key identifying the SignatureDef containing inputs and outputs. "
+ "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
+
+ # Quantization flags.
+ parser.add_argument(
+ "--std_dev_values",
+ type=str,
+ help=("Standard deviation of training data for each input tensor, "
+ "comma-separated. Used for quantization. (default None)"))
+ parser.add_argument(
+ "--mean_values",
+ type=str,
+ help=("Mean of training data for each input tensor, comma-separated. "
+ "Used for quantization. (default None)"))
+ parser.add_argument(
+ "--default_ranges_min",
+ type=int,
+ help=("Default value for min bound of min/max range values used for all "
+ "arrays without a specified range, Intended for experimenting with "
+ "quantization via \"dummy quantization\". (default None)"))
+ parser.add_argument(
+ "--default_ranges_max",
+ type=int,
+ help=("Default value for max bound of min/max range values used for all "
+ "arrays without a specified range, Intended for experimenting with "
+ "quantization via \"dummy quantization\". (default None)"))
+
+ # Graph manipulation flags.
+ parser.add_argument(
+ "--drop_control_dependency",
+ type=bool,
+ help=("Boolean indicating whether to drop control dependencies silently. "
+ "This is due to TensorFlow not supporting control dependencies. "
+ "(default True)"))
+ parser.add_argument(
+ "--reorder_across_fake_quant",
+ type=bool,
+ help=("Boolean indicating whether to reorder FakeQuant nodes in "
+ "unexpected locations. Used when the location of the FakeQuant "
+ "nodes is preventing graph transformations necessary to convert "
+ "the graph. Results in a graph that differs from the quantized "
+ "training graph, potentially causing differing arithmetic "
+ "behavior. (default False)"))
+ parser.add_argument(
+ "--change_concat_input_ranges",
+ type=bool,
+ help=("Boolean to change behavior of min/max ranges for inputs and "
+ "outputs of the concat operator for quantized models. Changes the "
+ "ranges of concat operator overlap when true. (default False)"))
+ parser.add_argument(
+ "--allow_custom_ops",
+ type=bool,
+ help=("Boolean indicating whether to allow custom operations. When false "
+ "any unknown operation is an error. When true, custom ops are "
+ "created for any op that is unknown. The developer will need to "
+ "provide these to the TensorFlow Lite runtime with a custom "
+ "resolver. (default False)"))
+
+ tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
+ try:
+ _check_flags(tflite_flags, unparsed)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = os.path.basename(sys.argv[0])
+ sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
+ sys.exit(1)
+ _convert_model(tflite_flags)
+
+
+def main():
+ app.run(main=run_main, argv=sys.argv[:1])
+
+
+if __name__ == "__main__":
+ main()
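A hypothetical invocation of the new tool via the py_binary added to BUILD above; paths and array names are placeholders:

    bazel run //tensorflow/contrib/lite/python:tflite_convert -- \
      --output_file=/tmp/model.tflite \
      --graph_def_file=/tmp/frozen_model.pb \
      --input_arrays=input \
      --output_arrays=output

Quantization goes through the same flags the Python API exposes, e.g. --inference_type=QUANTIZED_UINT8 with --mean_values and --std_dev_values (one value per entry in --input_arrays).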