author    Nupur Garg <nupurgarg@google.com>  2018-04-23 17:10:05 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-23 17:12:39 -0700
commit    771f7b46d631fa510658685d1b84ffbb22ffcd55 (patch)
tree      47d4f9a79eed86b926c09f0bbfc4180ea588bb3b
parent    a36e6edab33c7a5bef2f911d4d7bb88ffc8c7de6 (diff)
Improve TOCO SavedModel support.
PiperOrigin-RevId: 194009891
-rw-r--r--  tensorflow/contrib/lite/python/BUILD                                  |  45
-rw-r--r--  tensorflow/contrib/lite/python/convert.py                             | 187
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model.py                 | 387
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model_test.py            | 172
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py | 106
-rw-r--r--  tensorflow/contrib/lite/python/convert_test.py (renamed from tensorflow/contrib/lite/python/lite_test.py) | 41
-rw-r--r--  tensorflow/contrib/lite/python/lite.py                                | 204
-rw-r--r--  tensorflow/contrib/lite/python/lite_constants.py                      |  53
8 files changed, 828 insertions(+), 367 deletions(-)
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 926896d609..e6dcc7aa09 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -39,16 +39,35 @@ py_test(
py_library(
name = "lite",
srcs = ["lite.py"],
- # data = [
- # "//tensorflow/contrib/lite/toco/python:toco_from_protos",
- # ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":convert",
+ ":convert_saved_model",
":op_hint",
+ ],
+)
+
+py_library(
+ name = "lite_constants",
+ srcs = ["lite_constants.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ ],
+)
+
+py_library(
+ name = "convert",
+ srcs = ["convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite_constants",
"//tensorflow/contrib/lite/toco:model_flags_proto_py",
"//tensorflow/contrib/lite/toco:toco_flags_proto_py",
"//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco",
+ "//tensorflow/contrib/lite/toco/python:toco_from_protos",
"//tensorflow/python:platform",
],
)
@@ -66,15 +85,15 @@ py_library(
)
py_test(
- name = "lite_test",
- srcs = ["lite_test.py"],
+ name = "convert_test",
+ srcs = ["convert_test.py"],
srcs_version = "PY2AND3",
tags = [
"no-internal-py3",
"no_oss",
],
deps = [
- ":lite",
+ ":convert",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -84,13 +103,14 @@ py_test(
],
)
-py_binary(
+py_library(
name = "convert_saved_model",
srcs = ["convert_saved_model.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":lite",
+ ":convert",
+ ":lite_constants",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
"//tensorflow/python/tools:freeze_graph_lib",
@@ -130,6 +150,15 @@ py_test(
],
)
+py_binary(
+ name = "convert_saved_model_to_frozen_graph",
+ srcs = ["convert_saved_model_to_frozen_graph.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convert_saved_model",
+ ],
+)
+
# Transitive dependencies of this target will be included in the pip package.
py_library(
name = "tf_lite_py_pip",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
new file mode 100644
index 0000000000..c4200c879b
--- /dev/null
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts a frozen graph into a TFLite FlatBuffer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os as _os
+import subprocess as _subprocess
+import tempfile as _tempfile
+
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+
+# Lazy load since some of the performance benchmark skylark rules
+# break dependencies.
+_toco_python = LazyLoader(
+ "tensorflow_wrap_toco", globals(),
+ "tensorflow.contrib.lite.toco.python."
+ "tensorflow_wrap_toco")
+del LazyLoader
+
+# Find the toco_from_protos binary using the resource loader if running from
+# bazel; otherwise we are in a pip install where console_scripts already
+# provides the toco_from_protos tool.
+if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
+ _toco_from_proto_bin = ""
+else:
+ _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
+ "../toco/python/toco_from_protos")
+
+if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
+ _toco_from_proto_bin = "toco_from_protos"
+
+
+def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
+ """Convert `input_data_str` according to model and toco parameters.
+
+ Unless you know what you are doing, consider using
+ the more friendly @{tf.contrib.lite.toco_convert}.
+
+ Args:
+ model_flags_str: Serialized proto describing model properties, see
+ `toco/model_flags.proto`.
+ toco_flags_str: Serialized proto describing conversion properties, see
+ `toco/toco_flags.proto`.
+ input_data_str: Input data in serialized form (e.g. a graphdef is common)
+ Returns:
+ Converted model in serialized form (e.g. a TFLITE model is common).
+ Raises:
+ RuntimeError: When conversion fails, an exception is raised with the error
+ message embedded.
+ """
+ # TODO(aselle): When toco does not use fatal errors for failure, we can
+ # switch this on.
+ if not _toco_from_proto_bin:
+ return _toco_python.TocoConvert(
+ model_flags_str, toco_flags_str, input_data_str)
+
+ with _tempfile.NamedTemporaryFile() as fp_toco, \
+ _tempfile.NamedTemporaryFile() as fp_model, \
+ _tempfile.NamedTemporaryFile() as fp_input, \
+ _tempfile.NamedTemporaryFile() as fp_output:
+ fp_model.write(model_flags_str)
+ fp_toco.write(toco_flags_str)
+ fp_input.write(input_data_str)
+ fp_model.flush()
+ fp_toco.flush()
+ fp_input.flush()
+
+ cmd = [
+ _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
+ fp_output.name
+ ]
+ cmdline = " ".join(cmd)
+ proc = _subprocess.Popen(
+ cmdline,
+ shell=True,
+ stdout=_subprocess.PIPE,
+ stderr=_subprocess.STDOUT,
+ close_fds=True)
+ stdout, stderr = proc.communicate()
+ exitcode = proc.returncode
+ if exitcode == 0:
+ stuff = fp_output.read()
+ return stuff
+ else:
+ raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
+ (stdout, stderr))
+
+
+def tensor_name(x):
+ return x.name.split(":")[0]
+
+
+def toco_convert(input_data,
+ input_tensors,
+ output_tensors,
+ inference_type=lite_constants.FLOAT,
+ input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ output_format=lite_constants.TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Convert a model using TOCO from `input_format` to `output_format`.
+
+ Typically this is to convert from TensorFlow GraphDef to TFLite, in which
+ case the default `input_format` and `output_format` are sufficient.
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`).
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT).
+ quantized_input_stats: For each member of input_tensors, the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+ drop_control_dependency: Drops control dependencies silently. This is
+ because TensorFlow Lite does not support control dependencies.
+
+ Returns:
+ The converted data. For example, if TFLite is the destination, then
+ this will be a TFLite FlatBuffer in a bytes array.
+
+ Raises:
+ ValueError: If the input tensor type is unknown.
+ RuntimeError: If TOCO fails to convert (in which case the runtime error's
+ error text will contain the TOCO error log).
+ """
+ toco = _toco_flags_pb2.TocoFlags()
+ toco.input_format = input_format
+ toco.output_format = output_format
+ toco.drop_control_dependency = drop_control_dependency
+ model = _model_flags_pb2.ModelFlags()
+ toco.inference_type = inference_type
+ for idx, input_tensor in enumerate(input_tensors):
+ if input_tensor.dtype == _dtypes.float32:
+ tflite_input_type = lite_constants.FLOAT
+ elif input_tensor.dtype == _dtypes.int32:
+ tflite_input_type = lite_constants.INT32
+ elif input_tensor.dtype == _dtypes.int64:
+ tflite_input_type = lite_constants.INT64
+ # TODO(aselle): Insert strings when they are available
+ else:
+ raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
+ input_tensor.dtype))
+
+ input_array = model.input_arrays.add()
+
+ if inference_type == lite_constants.QUANTIZED_UINT8:
+ if tflite_input_type == lite_constants.FLOAT:
+ tflite_input_type = lite_constants.QUANTIZED_UINT8
+ input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
+
+ input_array.name = tensor_name(input_tensor)
+ input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
+
+ for output_tensor in output_tensors:
+ model.output_arrays.append(tensor_name(output_tensor))
+
+ # TODO(aselle): Consider handling the case of allowing quantized
+ # inputs to be converted to float (via the toco.inference_input_type field).
+ data = toco_convert_protos(model.SerializeToString(),
+ toco.SerializeToString(),
+ input_data.SerializeToString())
+ return data
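For reference, a minimal usage sketch of the new convert.toco_convert entry point, mirroring convert_test.testBasic below; the graph and the output path are illustrative:

    # Build a trivial graph and convert it to a TFLite FlatBuffer using the
    # new convert module. Shapes and the output path are illustrative.
    from tensorflow.contrib.lite.python import convert
    from tensorflow.python.client import session
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    with session.Session() as sess:
      in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                        dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      # Returns the serialized TFLite model as a bytes array.
      tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
                                          [out_tensor])

    with open("/tmp/model.lite", "wb") as f:
      f.write(tflite_model)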
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index a2b5ef488e..a7eddf3408 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -12,52 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-r"""TensorFlow Lite flatbuffer generation from saved_models.
+"""Functions to convert SavedModel to frozen GraphDefs."""
-Example:
-
-bazel run third_party/tensorflow/contrib/lite/python:convert_saved_model -- \
- --saved_model_dir=/tmp/test_saved_model/1519865537 \
- --output_tflite=/tmp/test.lite
-
-"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import convert
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.toco import model_flags_pb2
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
-flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.")
-flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.")
-flags.DEFINE_string("output_arrays", None,
- "List of output tensor names, the default value is None, "
- "which means the conversion will keep all outputs.")
-flags.DEFINE_integer("batch_size", 1,
- "If input tensor shape has None at first dimension, "
- "e.g. (None,224,224,3), replace None with batch_size.")
-flags.DEFINE_string("tag_set", tag_constants.SERVING,
- "Group of tag(s) of the MetaGraphDef in the saved_model, "
- "in string format, separated by ','. For tag-set contains "
- "multiple tags, all tags must be passed in.")
-flags.DEFINE_string("signature_key",
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- "This is signature key to extract inputs, outputs.")
-
-
-def log_tensor_details(tensor_info):
+
+def _write_and_flush_file(file_path, data_str):
+ """Writes data to file path.
+
+ Args:
+ file_path: Full path of the file to store data in.
+ data_str: Data represented as a string.
+
+ Returns: None.
+ """
+ with gfile.Open(file_path, "wb") as data_file:
+ data_file.write(data_str)
+ data_file.flush()
+
+
+def _log_tensor_details(tensor_info):
"""Log tensor details: name, shape, and type."""
for key in tensor_info:
val = tensor_info[key]
@@ -73,7 +64,7 @@ def log_tensor_details(tensor_info):
dtype)
-def get_meta_graph_def(saved_model_dir, tag_set):
+def _get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
@@ -103,7 +94,7 @@ def get_meta_graph_def(saved_model_dir, tag_set):
"values are '{}'. ".format(tag_set, tag_sets))
-def get_signature_def(meta_graph, signature_key):
+def _get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
@@ -130,11 +121,11 @@ def get_signature_def(meta_graph, signature_key):
return signature_def
-def get_inputs_outputs(signature_def):
- """Get inputs and outputs from signature def.
+def _get_inputs_outputs(signature_def):
+ """Get inputs and outputs from SignatureDef.
Args:
- signature_def: signatuer def in the meta_graph_def for conversion.
+ signature_def: SignatureDef in the meta_graph_def for conversion.
Returns:
The inputs and outputs in the graph for conversion.
@@ -142,9 +133,9 @@ def get_inputs_outputs(signature_def):
inputs_tensor_info = signature_def.inputs
outputs_tensor_info = signature_def.outputs
logging.info("input tensors info: ")
- log_tensor_details(inputs_tensor_info)
+ _log_tensor_details(inputs_tensor_info)
logging.info("output tensors info: ")
- log_tensor_details(outputs_tensor_info)
+ _log_tensor_details(outputs_tensor_info)
def gather_names(tensor_info):
return [tensor_info[key].name for key in tensor_info]
@@ -154,109 +145,277 @@ def get_inputs_outputs(signature_def):
return inputs, outputs
-def convert(saved_model_dir,
- output_tflite=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1):
- """Convert a saved_model to tflite flatbuffer.
+def _get_tensors(graph, signature_def_tensor_names=None,
+ user_tensor_names=None):
+ """Gets the tensors associated with the tensor names.
+
+ Either signature_def_tensor_names or user_tensor_names should be provided. If
+ user_tensor_names are given, the tensors associated with those names are
+ returned. Otherwise, the tensors associated with the names in the
+ SignatureDef are returned.
Args:
- saved_model_dir: Saved model directory to convert.
- output_tflite: File path to write result flatbuffer.
- output_arrays: List of output tensor names, the default value is None, which
- means conversion keeps all output tensors. This is also used to filter
- tensors that are from Op currently not supported in tflite, e.g., Argmax).
- tag_set: This is the set of tags to get meta_graph_def in saved_model.
- signature_key: This is the signature key to extract inputs, outputs.
- batch_size: If input tensor shape has None at first dimension,
- e.g. (None,224,224,3), replace None with batch_size.
+ graph: Graph object representing the TensorFlow graph.
+ signature_def_tensor_names: Tensor names stored in either the inputs or
+ outputs of a SignatureDef. (default None)
+ user_tensor_names: Tensor names provided by the user. (default None)
Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
+ List of tensors.
+
+ Raises:
+ ValueError:
+ signature_def_tensor_names and user_tensor_names are undefined or empty.
+ user_tensor_names are not valid.
+ """
+ tensors = []
+ if user_tensor_names:
+ # Get the list of all of the tensors with and without the tensor index.
+ all_tensor_names = [
+ tensor.name for op in graph.get_operations() for tensor in op.outputs
+ ]
+ all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
+
+ # Sort the tensor names.
+ user_tensor_names = sorted(user_tensor_names)
+
+ # Get the tensors associated with the tensor names.
+ invalid_tensors = []
+ for name in user_tensor_names:
+ if name not in all_tensor_names_only:
+ invalid_tensors.append(name)
+ else:
+ idx = all_tensor_names_only.index(name)
+ tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
+
+ # Throw ValueError if any user input names are not valid tensors.
+ if invalid_tensors:
+ raise ValueError("Invalid tensors '{}' were found.".format(
+ ",".join(invalid_tensors)))
+ elif signature_def_tensor_names:
+ tensors = [
+ graph.get_tensor_by_name(name)
+ for name in sorted(signature_def_tensor_names)
+ ]
+ else:
+ # Throw ValueError if signature_def_tensors and user_tensor_names are both
+ # either undefined or empty.
+ raise ValueError(
+ "Specify either signature_def_tensor_names or user_tensor_names")
+
+ return tensors
+
+
+def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key, batch_size):
+ """Converts a SavedModel to a frozen graph.
+
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
+ Automatically determined when a shape is None (e.g., {"foo": None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+ input shape if it is None. (default 1)
+
+ Returns:
+ frozen_graph_def: Frozen GraphDef.
+ in_tensors: List of input tensors for the graph.
+ out_tensors: List of output tensors for the graph.
Raises:
- ValueError: If tag_set does not indicate any meta_graph_def in saved_model,
- or signature_key is not in relevant meta_graph_def,
- or input shape has None beyond 1st dimension, e.g., (1,None, None, 3),
- or given output_arrays are not valid causing empty outputs.
+ ValueError:
+ SavedModel doesn't contain a MetaGraphDef identified by tag_set.
+ signature_key is not in the MetaGraphDef.
+ input_shapes does not match the length of input_arrays.
+ input_shapes has a None value after the 1st dimension.
+ input_arrays or output_arrays are not valid.
+ Unable to load Session.
"""
+ # Set default values for inputs if they are set to None.
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
if tag_set is None:
tag_set = set([tag_constants.SERVING])
+ if batch_size is None:
+ batch_size = 1
- meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
- signature_def = get_signature_def(meta_graph, signature_key)
- inputs, outputs = get_inputs_outputs(signature_def)
+ # Read SignatureDef.
+ meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = _get_signature_def(meta_graph, signature_key)
+ inputs, outputs = _get_inputs_outputs(signature_def)
graph = ops.Graph()
with session.Session(graph=graph) as sess:
-
+ # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
- in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs]
-
- # Users can use output_arrays to filter output tensors for conversion.
- # If output_arrays is None, we keep all output tensors. In future, we may
- # use tflite supported Op list and check whether op is custom Op to
- # automatically filter output arrays.
- # TODO(zhixianyan): Use tflite supported Op list to filter outputs.
- if output_arrays is not None:
- output_arrays = output_arrays.split(",")
- out_tensors = [
- graph.get_tensor_by_name(output)
- for output in outputs
- if output.split(":")[0] in output_arrays
- ]
- else:
- out_tensors = [graph.get_tensor_by_name(output) for output in outputs]
+ # Gets input and output tensors.
+ # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
+ in_tensors = _get_tensors(graph, inputs, input_arrays)
+ out_tensors = _get_tensors(graph, outputs, output_arrays)
- output_names = [node.split(":")[0] for node in outputs]
+ # Gets fully defined tensor shape. An input tensor with None in the first
+ # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size.
+ # Shapes with None after the first dimension result in a ValueError.
+ # TODO(zhixianyan): Add supports for input tensor with more None in shape.
+ for tensor in in_tensors:
+ if (input_shapes and tensor.name in input_shapes and
+ input_shapes[tensor.name] is not None):
+ shape = input_shapes[tensor.name]
+ else:
+ shape = tensor.get_shape().as_list()
- if not out_tensors:
- raise ValueError(
- "No valid output tensors for '{}', possible values are '{}'".format(
- output_arrays, output_names))
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(tensor.name, shape))
+ elif shape[0] is None:
+ shape[0] = batch_size
+ tensor.set_shape(shape)
+ output_names = [node.split(":")[0] for node in outputs]
frozen_graph_def = tf_graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), output_names)
- # Toco requires fully defined tensor shape, for input tensor with None in
- # their shape, e.g., (None, 224, 224, 3), we need to replace first None with
- # a given batch size. For shape with more None, e.g. (None, None, None, 3),
- # still be able to replace and convert, but require further investigation.
- # TODO(zhixianyan): Add supports for input tensor with more None in shape.
- for i in range(len(in_tensors)):
- shape = in_tensors[i].get_shape().as_list()
- if shape[0] is None:
- shape[0] = batch_size
- if None in shape[1:]:
- raise ValueError(
- "Only support None shape at 1st dim as batch_size. But tensor "
- "'{}' 's shape '{}' has None at other dimension. ".format(
- inputs[i], shape))
- in_tensors[i].set_shape(shape)
+ return frozen_graph_def, in_tensors, out_tensors
+ raise ValueError("Unable to load Session.")
- result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors)
- if output_tflite is not None:
- with gfile.Open(output_tflite, "wb") as f:
- f.write(result)
- logging.info("Successfully converted to: %s", output_tflite)
+def saved_model_to_frozen_graphdef(
+ saved_model_dir,
+ output_file_model,
+ output_file_flags,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ batch_size=1):
+ """Converts a SavedModel to a frozen graph. Writes graph to tmp directory.
- return result
+ Stores frozen graph and command line flags in the tmp directory.
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ output_file_model: Full file path to save frozen graph.
+ output_file_flags: Full file path to save ModelFlags.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
+ Automatically determined when a shape is None (e.g., {"foo": None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+ input shape if it is None. (default 1)
+
+ Returns: None.
-def main(_):
- convert(
- saved_model_dir=flags.FLAGS.saved_model_dir,
- output_tflite=flags.FLAGS.output_tflite,
- output_arrays=flags.FLAGS.output_arrays,
- batch_size=flags.FLAGS.batch_size,
- tag_set=set(flags.FLAGS.tag_set.split(",")),
- signature_key=flags.FLAGS.signature_key)
+ Raises:
+ ValueError: Unable to convert to frozen graph.
+ """
+ frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
+ saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
+ signature_key, batch_size)
+
+ # Initialize model flags.
+ model = model_flags_pb2.ModelFlags()
+
+ for input_tensor in in_tensors:
+ input_array = model.input_arrays.add()
+ input_array.name = convert.tensor_name(input_tensor)
+ input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
+
+ for output_tensor in out_tensors:
+ model.output_arrays.append(convert.tensor_name(output_tensor))
+
+ # Write model and ModelFlags to file. ModelFlags contain input array and
+ # output array information that is parsed from the SignatureDef and used for
+ # analysis by TOCO.
+ _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString())
+ _write_and_flush_file(output_file_flags, model.SerializeToString())
+
+
+def tflite_from_saved_model(
+ saved_model_dir,
+ output_file=None,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ batch_size=1,
+ inference_type=lite_constants.FLOAT,
+ input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ output_format=lite_constants.TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Converts a SavedModel to TFLite FlatBuffer.
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ output_file: File path to write result TFLite FlatBuffer.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
+ Automatically determined when a shape is None (e.g., {"foo": None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+ input shape if it is None. (default 1)
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT).
+ quantized_input_stats: For each member of input_tensors, the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+ drop_control_dependency: Drops control dependencies silently. This is
+ because TensorFlow Lite does not support control dependencies.
-if __name__ == "__main__":
- app.run(main)
+ Returns:
+ The converted data. For example, if TFLite is the destination, then
+ this will be a TFLite FlatBuffer in a bytes array.
+
+ Raises:
+ ValueError: Unable to convert to frozen graph.
+ """
+ frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
+ saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
+ signature_key, batch_size)
+
+ result = convert.toco_convert(
+ input_data=frozen_graph_def,
+ input_tensors=in_tensors,
+ output_tensors=out_tensors,
+ inference_type=inference_type,
+ input_format=input_format,
+ output_format=output_format,
+ quantized_input_stats=quantized_input_stats,
+ drop_control_dependency=drop_control_dependency)
+
+ if output_file is not None:
+ with gfile.Open(output_file, "wb") as f:
+ f.write(result)
+ logging.info("Successfully converted to: %s", output_file)
+
+ return result
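A sketch of the two new SavedModel entry points; the directory, tensor names, and paths are illustrative (they mirror the tests in the next file):

    from tensorflow.contrib.lite.python import convert_saved_model

    # Convert a SavedModel directly to a TFLite FlatBuffer. The result is
    # returned and, when output_file is set, also written to disk.
    tflite_model = convert_saved_model.tflite_from_saved_model(
        saved_model_dir="/tmp/simple_savedmodel",
        output_file="/tmp/model.lite",
        input_arrays=["Placeholder"],
        input_shapes={"Placeholder": [1, 16, 16, 3]},
        output_arrays=["add"])

    # Or stop at a frozen GraphDef plus ModelFlags for later analysis by TOCO.
    convert_saved_model.saved_model_to_frozen_graphdef(
        saved_model_dir="/tmp/simple_savedmodel",
        output_file_model="/tmp/frozen.pb",
        output_file_flags="/tmp/model_flags.pb",
        input_arrays=["Placeholder"])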
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index 734e42d619..db95fc8ad7 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""TF Lite SavedModel Conversion test cases.
-
- - test on generated saved_models from simple graphs (sanity check)
- - test mnist savedmodel generated on-the-fly
+"""TFLite SavedModel conversion test cases.
+ - Tests converting simple SavedModel graph to TFLite FlatBuffer.
+ - Tests converting simple SavedModel graph to frozen graph.
+ - Tests converting MNIST SavedModel to TFLite FlatBuffer.
"""
from __future__ import absolute_import
@@ -25,6 +25,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.lite.python import convert_saved_model
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib as estimator
@@ -37,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training import training as train
@@ -45,7 +47,7 @@ from tensorflow.python.training import training as train
class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
- """Create a simple savedmodel on the fly."""
+ """Create a simple SavedModel on the fly."""
saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
with session.Session() as sess:
in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32)
@@ -56,44 +58,78 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
return saved_model_dir
def testSimpleSavedModel(self):
- """Test a simple savedmodel created on the fly."""
- # Create a simple savedmodel
+ """Test a simple SavedModel created on the fly."""
+ # Create a simple SavedModel
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite
- result = convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
self.assertTrue(result)
def testSimpleSavedModelWithNoneBatchSizeInShape(self):
- """Test a simple savedmodel, with None in input tensor's shape."""
+ """Test a simple SavedModel, with None in input tensor's shape."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
- result = convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
self.assertTrue(result)
def testSimpleSavedModelWithMoreNoneInShape(self):
- """Test a simple savedmodel, fail as more None in input shape."""
+ """Test a simple SavedModel, fail as more None in input shape."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3])
# Convert to tflite: this should raise ValueError, as 3rd dim is None.
with self.assertRaises(ValueError):
- convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
def testSimpleSavedModelWithWrongSignatureKey(self):
- """Test a simple savedmodel, fail as given signature is invalid."""
+ """Test a simple SavedModel, fail as given signature is invalid."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite: this should raise ValueError, as
# signature_key does not exist in the saved_model.
with self.assertRaises(ValueError):
- convert_saved_model.convert(
+ convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_dir, signature_key="wrong-key")
def testSimpleSavedModelWithWrongOutputArray(self):
- """Test a simple savedmodel, fail as given output_arrays is invalid."""
- # Create a simple savedmodel
+ """Test a simple SavedModel, fail as given output_arrays is invalid."""
+ # Create a simple SavedModel
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite: this should raise ValueError, as
# output_arrays is not valid for the saved_model.
with self.assertRaises(ValueError):
- convert_saved_model.convert(
- saved_model_dir=saved_model_dir, output_arrays="wrong-output")
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir, output_arrays=["wrong-output"])
+
+ def testSimpleSavedModelWithWrongInputArrays(self):
+ """Test a simple SavedModel, fail as given input_arrays is invalid."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ # Checks invalid input_arrays.
+ with self.assertRaises(ValueError):
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir, input_arrays=["wrong-input"])
+ # Checks valid and invalid input_arrays.
+ with self.assertRaises(ValueError):
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder", "wrong-input"])
+
+ def testSimpleSavedModelWithCorrectArrays(self):
+ """Test a simple SavedModel, with correct input_arrays and output_arrays."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder"],
+ output_arrays=["add"])
+ self.assertTrue(result)
+
+ def testSimpleSavedModelWithCorrectInputArrays(self):
+ """Test a simple SavedModel, with correct input_arrays and input_shapes."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder"],
+ input_shapes={"Placeholder": [1, 16, 16, 3]})
+ self.assertTrue(result)
def testMultipleMetaGraphDef(self):
"""Test saved model with multiple MetaGraphDef."""
@@ -119,20 +155,103 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
sess,
tags=[saved_model.tag_constants.SERVING, "additional_test_tag"],
signature_def_map=signature_def_map)
+
# MetaGraphDef 2
builder.add_meta_graph(tags=["tflite"])
builder.save(True)
# Convert to tflite
- convert_saved_model.convert(
+ convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_dir,
tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"]))
+class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase):
+
+ def _createSimpleSavedModel(self, shape):
+ """Create a simple SavedModel."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
+ with session.Session() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputB")
+ in_tensor_2 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputA")
+ out_tensor = in_tensor_1 + in_tensor_2
+ inputs = {"x": in_tensor_1, "y": in_tensor_2}
+ outputs = {"z": out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ return saved_model_dir
+
+ def _getInputArrayNames(self, model_proto):
+ return [data.name for data in model_proto.input_arrays]
+
+ def _getInputArrayShapes(self, model_proto):
+ return [
+ [dim for dim in data.shape.dims] for data in model_proto.input_arrays
+ ]
+
+ def _get_model_flags_proto_from_file(self, filename):
+ proto = _model_flags_pb2.ModelFlags()
+ with gfile.Open(filename, "rb") as output_file:
+ proto.ParseFromString(output_file.read())
+ return proto
+
+ def testSimpleSavedModel(self):
+ """Test a simple SavedModel."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
+ output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
+
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputB", "inputA"])
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"])
+ self.assertEqual(
+ self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]])
+
+ def testSimpleSavedModelWithDifferentInputNames(self):
+ """Test a simple SavedModel."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
+ output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
+
+ # Check case where input shape is given.
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputA"],
+ input_shapes={"inputA": [1, 16, 16, 3]})
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
+ self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
+
+ # Check case where input shape is None.
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputA"],
+ input_shapes={"inputA": None})
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
+ self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
+
+
class Model(keras.Model):
"""Model to recognize digits in the MNIST dataset.
- Train and export savedmodel, used for testOnflyTrainMnistSavedModel
+ Train and export SavedModel, used for testOnflyTrainMnistSavedModel
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
@@ -238,7 +357,7 @@ def dummy_input_fn():
class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
def testTrainedMnistSavedModel(self):
- """Test mnist savedmodel, trained with dummy data and small steps."""
+ """Test mnist SavedModel, trained with dummy data and small steps."""
# Build classifier
classifier = estimator.Estimator(
model_fn=model_fn,
@@ -253,21 +372,20 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
"image": image,
})
- # Export savedmodel
+ # Export SavedModel
saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel")
classifier.export_savedmodel(saved_model_dir, pred_input_fn)
# Convert to tflite and test output
saved_model_name = os.listdir(saved_model_dir)[0]
saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name)
- output_tflite = os.path.join(saved_model_dir,
- saved_model_final_dir + ".lite")
+ output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite")
# TODO(zhixianyan): no need to limit output_arrays to `Softmax'
# once b/74205001 fixed and argmax implemented in tflite.
- result = convert_saved_model.convert(
+ result = convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_final_dir,
- output_arrays="Softmax",
- output_tflite=output_tflite)
+ output_arrays=["Softmax"],
+ output_file=output_file)
self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
new file mode 100644
index 0000000000..4d9782f4a6
--- /dev/null
+++ b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python console command for generating frozen models from SavedModels.
+
+This exists to add SavedModel compatibility to TOCO.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef
+from tensorflow.python.platform import app
+
+FLAGS = None
+
+
+def execute(unused_args):
+ """Calls function to convert the SavedModel to a frozen graph."""
+ # Error handling.
+ if FLAGS.input_shapes and not FLAGS.input_arrays:
+ raise ValueError("Input shapes requires input arrays to be specified.")
+
+ # Calls saved_model_to_frozen_graphdef function to generate frozen graph.
+ input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None)
+ input_shapes = None
+ if FLAGS.input_shapes:
+ input_shapes = {
+ # Shapes are colon-separated per array; dims must be ints for set_shape.
+ input_arrays[idx]: [int(dim) for dim in shape.split(",")]
+ for idx, shape in enumerate(FLAGS.input_shapes.split(":"))
+ }
+ output_arrays = (
+ FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None)
+ tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None
+
+ saved_model_to_frozen_graphdef(
+ saved_model_dir=FLAGS.saved_model_directory,
+ output_file_model=FLAGS.output_file_model,
+ output_file_flags=FLAGS.output_file_flags,
+ input_arrays=input_arrays,
+ input_shapes=input_shapes,
+ output_arrays=output_arrays,
+ tag_set=tag_set,
+ signature_key=FLAGS.signature_key,
+ batch_size=FLAGS.batch_size)
+
+
+def main():
+ global FLAGS
+ # Parses flags.
+ parser = argparse.ArgumentParser(
+ description="Invoke SavedModel to frozen model converter.")
+ parser.add_argument(
+ "saved_model_directory",
+ type=str,
+ help="Full path to directory containing the SavedModel.")
+ parser.add_argument(
+ "output_file_model",
+ type=str,
+ help="Full file path to save frozen graph.")
+ parser.add_argument(
+ "output_file_flags", type=str, help="Full file path to save ModelFlags.")
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+ help="Name of the input arrays, comma-separated.")
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Name of the output arrays, comma-separated.")
+ parser.add_argument(
+ "--tag_set", type=str, help="Name of output arrays, comma-separated.")
+ parser.add_argument(
+ "--signature_key",
+ type=str,
+ help="Key identifying SignatureDef containing inputs and outputs.")
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ help="Batch size for the model. Replaces the first dimension of an "
+ "input size array if undefined.")
+
+ FLAGS, unparsed = parser.parse_known_args()
+
+ app.run(main=execute, argv=[sys.argv[0]] + unparsed)
+
+
+if __name__ == "__main__":
+ main()
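Following the convention of the bazel example in the old module docstring removed above, a hypothetical invocation of this console command (paths and array names are illustrative):

    bazel run //tensorflow/contrib/lite/python:convert_saved_model_to_frozen_graph -- \
      /tmp/simple_savedmodel /tmp/frozen.pb /tmp/model_flags.pb \
      --input_arrays=inputA,inputB \
      --input_shapes=1,16,16,3:1,16,16,3 \
      --output_arrays=add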
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/convert_test.py
index b8b4510188..dc21a9b669 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import lite
-from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base
+from tensorflow.contrib.lite.python import convert
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.python import op_hint
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class LiteTest(test_util.TensorFlowTestCase):
+class ConvertTest(test_util.TensorFlowTestCase):
def testBasic(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
@@ -37,13 +38,13 @@ class LiteTest(test_util.TensorFlowTestCase):
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Try running on valid graph
- result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
+ result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
self.assertTrue(result)
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
# all the time).
# Try running on identity graph (known fail)
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
- # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
+ # result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
def testQuantization(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
@@ -51,13 +52,14 @@ class LiteTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor],
- inference_type=lite.QUANTIZED_UINT8,
- quantized_input_stats=[(0., 1.)])
+ result = convert.toco_convert(
+ sess.graph_def, [in_tensor], [out_tensor],
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.)])
self.assertTrue(result)
-class LiteTestOpHint(test_util.TensorFlowTestCase):
+class ConvertTestOpHint(test_util.TensorFlowTestCase):
"""Test the hint to stub functionality."""
def _getGraphOpTypes(self, graphdef, output_nodes):
@@ -99,7 +101,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
swish_scale = array_ops.constant(1.0)
def _swish(input_tensor, scale):
- custom = lite.OpHint("cool_activation")
+ custom = op_hint.OpHint("cool_activation")
input_tensor, scale = custom.add_inputs(input_tensor, scale)
output = math_ops.sigmoid(input_tensor) * input_tensor * scale
output, = custom.add_outputs(output)
@@ -111,11 +113,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["cool_activation", "Const", "Identity"])
def testScaleAndBiasAndIdentity(self):
@@ -125,7 +128,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
b = array_ops.constant([4., 5.])
def _scaled_and_bias_and_identity(a, x, b):
- custom = lite.OpHint("scale_and_bias_and_identity")
+ custom = op_hint.OpHint("scale_and_bias_and_identity")
a, x, b = custom.add_inputs(a, x, b)
return custom.add_outputs(a * x + b, x)
output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
@@ -136,11 +139,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
def testTwoFunctions(self):
@@ -148,7 +152,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
a = array_ops.constant([1.])
b = array_ops.constant([1.])
def _double_values(x):
- custom = lite.OpHint("add_test")
+ custom = op_hint.OpHint("add_test")
x = custom.add_inputs(x)
output = math_ops.multiply(x, x)
output, = custom.add_outputs(output)
@@ -160,10 +164,11 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["add_test", "Const", "Identity", "Add"])
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index cf50f9d4d6..4ea40201f7 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -18,6 +18,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@toco_convert
@@toco_convert_protos
+@@tflite_from_saved_model
@@OpHint
@@convert_op_hints_to_stubs
@@ -25,208 +26,11 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os as _os
-import subprocess as _subprocess
-import tempfile as _tempfile
# pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert
+from tensorflow.contrib.lite.python.convert import toco_convert_protos
+from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.contrib.lite.python.op_hint import OpHint
# pylint: enable=unused-import
-from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
-from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
-from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
-from tensorflow.python.framework import dtypes as _dtypes
-from tensorflow.python.platform import resource_loader as _resource_loader
-from tensorflow.python.util.all_util import remove_undocumented
-from tensorflow.python.util.lazy_loader import LazyLoader
-
-# Lazy load since some of the performance benchmark skylark rules
-# break dependencies.
-_toco_python = LazyLoader(
- "tensorflow_wrap_toco", globals(),
- "tensorflow.contrib.lite.toco.python."
- "tensorflow_wrap_toco")
-del LazyLoader
-
-# Enum types from the protobuf promoted to the API
-FLOAT = _types_pb2.FLOAT
-INT32 = _types_pb2.INT32
-INT64 = _types_pb2.INT64
-STRING = _types_pb2.STRING
-QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
-TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
-TFLITE = _toco_flags_pb2.TFLITE
-GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
-
-# Currently the default mode of operation is to shell to another python process
-# to protect against crashes. However, it breaks some dependent targets because
-# it forces us to depend on an external py_binary. The experimental API doesn't
-# have that drawback.
-EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
-
-# Find the toco_from_protos binary using the resource loader if using from
-# bazel, otherwise we are in a pip where console_scripts already has
-# the toco_from_protos tool.
-if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
- _toco_from_proto_bin = ""
-else:
- _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
- "../toco/python/toco_from_protos")
-
-if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
- _toco_from_proto_bin = "toco_from_protos"
-
-
-def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
- """Convert `input_data_str` according to model and toco parameters.
-
- Unless you know what you are doing consider using
- the more friendly @{tf.contrib.lite.toco_convert}}.
-
- Args:
- model_flags_str: Serialized proto describing model properties, see
- `toco/model_flags.proto`.
- toco_flags_str: Serialized proto describing conversion properties, see
- `toco/toco_flags.proto`.
- input_data_str: Input data in serialized form (e.g. a graphdef is common)
- Returns:
- Converted model in serialized form (e.g. a TFLITE model is common).
- Raises:
- RuntimeError: When conversion fails, an exception is raised with the error
- message embedded.
- """
- # TODO(aselle): When toco does not use fatal errors for failure, we can
- # switch this on.
- if not _toco_from_proto_bin:
- return _toco_python.TocoConvert(
- model_flags_str, toco_flags_str, input_data_str)
-
- with _tempfile.NamedTemporaryFile() as fp_toco, \
- _tempfile.NamedTemporaryFile() as fp_model, \
- _tempfile.NamedTemporaryFile() as fp_input, \
- _tempfile.NamedTemporaryFile() as fp_output:
- fp_model.write(model_flags_str)
- fp_toco.write(toco_flags_str)
- fp_input.write(input_data_str)
- fp_model.flush()
- fp_toco.flush()
- fp_input.flush()
-
- cmd = [
- _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
- fp_output.name
- ]
- cmdline = " ".join(cmd)
- proc = _subprocess.Popen(
- cmdline,
- shell=True,
- stdout=_subprocess.PIPE,
- stderr=_subprocess.STDOUT,
- close_fds=True)
- stdout, stderr = proc.communicate()
- exitcode = proc.returncode
- if exitcode == 0:
- stuff = fp_output.read()
- return stuff
- else:
- raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
- (stdout, stderr))
-
-
-def _tensor_name(x):
- return x.name.split(":")[0]
-
-
-def toco_convert(input_data,
- input_tensors,
- output_tensors,
- inference_type=FLOAT,
- input_format=TENSORFLOW_GRAPHDEF,
- output_format=TFLITE,
- quantized_input_stats=None,
- drop_control_dependency=True,
- allow_custom_ops=None):
- """Convert a model using TOCO from `input_format` to `output_format`.
-
- Typically this is to convert from TensorFlow GraphDef to TFLite, in which
- case the default `input_format` and `output_format` are sufficient.
-
- Args:
- input_data: Input data (i.e. often `sess.graph_def`).
- input_tensors: List of input tensors. Type and shape are computed using
- `foo.get_shape()` and `foo.dtype`.
- output_tensors: List of output tensors (only .name is used from this).
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
-
- Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
-
- Raises:
- ValueError: If the input tensor type is unknown
- RuntimeError: If TOCO fails to convert (in which case the runtime error's
- error text will contain the TOCO error log)
- """
- toco = _toco_flags_pb2.TocoFlags()
- toco.input_format = input_format
- toco.output_format = output_format
- toco.inference_type = inference_type
- toco.drop_control_dependency = drop_control_dependency
- if allow_custom_ops is not None:
- toco.allow_custom_ops = allow_custom_ops
-
- model = _model_flags_pb2.ModelFlags()
- for idx, input_tensor in enumerate(input_tensors):
- if input_tensor.dtype == _dtypes.float32:
- tflite_input_type = FLOAT
- elif input_tensor.dtype == _dtypes.int32:
- tflite_input_type = INT32
- elif input_tensor.dtype == _dtypes.int64:
- tflite_input_type = INT64
- # TODO(aselle): Insert strings when they are available
- else:
- raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
- input_tensor.dtype))
-
- input_array = model.input_arrays.add()
-
- if inference_type == QUANTIZED_UINT8:
- if tflite_input_type == FLOAT:
- tflite_input_type = QUANTIZED_UINT8
- input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
-
- input_array.name = _tensor_name(input_tensor)
- input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
-
- for output_tensor in output_tensors:
- model.output_arrays.append(_tensor_name(output_tensor))
-
- # TODO(aselle): Consider handling the case of allowing quantized
- # inputs to be converted to float (via the toco.inference_input_type field).
- data = toco_convert_protos(model.SerializeToString(),
- toco.SerializeToString(),
- input_data.SerializeToString())
- return data
-
-
-_allowed_symbols = [
- "FLOAT",
- "INT32",
- "INT64",
- "STRING",
- "QUANTIZED_UINT8",
- "TENSORFLOW_GRAPHDEF",
- "TFLITE",
- "GRAPHVIZ_DOT",
- "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
-]
-remove_undocumented(__name__, _allowed_symbols)
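With this refactor, lite.py becomes a thin re-export layer, so existing call sites keep working unchanged. A minimal sketch (the graph setup is illustrative):

    from tensorflow.contrib.lite.python import lite
    from tensorflow.python.client import session
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    with session.Session() as sess:
      in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                        dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      # toco_convert is now re-exported from convert.py.
      tflite_model = lite.toco_convert(sess.graph_def, [in_tensor],
                                       [out_tensor])

Note that the enum constants (FLOAT, QUANTIZED_UINT8, TENSORFLOW_GRAPHDEF, ...) are no longer exported from lite; they move to lite_constants below.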
diff --git a/tensorflow/contrib/lite/python/lite_constants.py b/tensorflow/contrib/lite/python/lite_constants.py
new file mode 100644
index 0000000000..195d7a732f
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite_constants.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for TFLite."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.util.all_util import remove_undocumented
+
+# Enum types from the protobuf promoted to the API
+FLOAT = _types_pb2.FLOAT
+INT32 = _types_pb2.INT32
+INT64 = _types_pb2.INT64
+STRING = _types_pb2.STRING
+QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
+TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
+TFLITE = _toco_flags_pb2.TFLITE
+GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
+
+# Currently the default mode of operation is to shell to another python process
+# to protect against crashes. However, it breaks some dependent targets because
+# it forces us to depend on an external py_binary. The experimental API doesn't
+# have that drawback.
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
+
+
+_allowed_symbols = [
+ "FLOAT",
+ "INT32",
+ "INT64",
+ "STRING",
+ "QUANTIZED_UINT8",
+ "TENSORFLOW_GRAPHDEF",
+ "TFLITE",
+ "GRAPHVIZ_DOT",
+ "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
+]
+remove_undocumented(__name__, _allowed_symbols)
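A sketch of quantized conversion using the relocated constants, mirroring convert_test.testQuantization (the graph setup is illustrative):

    from tensorflow.contrib.lite.python import convert
    from tensorflow.contrib.lite.python import lite_constants
    from tensorflow.python.client import session
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    with session.Session() as sess:
      in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                        dtype=dtypes.float32)
      out_tensor = array_ops.fake_quant_with_min_max_args(
          in_tensor + in_tensor, min=0., max=1.)
      tflite_model = convert.toco_convert(
          sess.graph_def, [in_tensor], [out_tensor],
          inference_type=lite_constants.QUANTIZED_UINT8,
          quantized_input_stats=[(0., 1.)])  # (mean, std) for each input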