diff options
author | 2018-04-23 17:10:05 -0700 | |
---|---|---|
committer | 2018-04-23 17:12:39 -0700 | |
commit | 771f7b46d631fa510658685d1b84ffbb22ffcd55 (patch) | |
tree | 47d4f9a79eed86b926c09f0bbfc4180ea588bb3b | |
parent | a36e6edab33c7a5bef2f911d4d7bb88ffc8c7de6 (diff) |
Improve TOCO SavedModel support.
PiperOrigin-RevId: 194009891
-rw-r--r-- | tensorflow/contrib/lite/python/BUILD | 45 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 187 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model.py | 387 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model_test.py | 172 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py | 106 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_test.py (renamed from tensorflow/contrib/lite/python/lite_test.py) | 41 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 204 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite_constants.py | 53 |
8 files changed, 828 insertions, 367 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 926896d609..e6dcc7aa09 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -39,16 +39,35 @@ py_test( py_library( name = "lite", srcs = ["lite.py"], - # data = [ - # "//tensorflow/contrib/lite/toco/python:toco_from_protos", - # ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":convert", + ":convert_saved_model", ":op_hint", + ], +) + +py_library( + name = "lite_constants", + srcs = ["lite_constants.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + ], +) + +py_library( + name = "convert", + srcs = ["convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite_constants", "//tensorflow/contrib/lite/toco:model_flags_proto_py", "//tensorflow/contrib/lite/toco:toco_flags_proto_py", "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", + "//tensorflow/contrib/lite/toco/python:toco_from_protos", "//tensorflow/python:platform", ], ) @@ -66,15 +85,15 @@ py_library( ) py_test( - name = "lite_test", - srcs = ["lite_test.py"], + name = "convert_test", + srcs = ["convert_test.py"], srcs_version = "PY2AND3", tags = [ "no-internal-py3", "no_oss", ], deps = [ - ":lite", + ":convert", ":op_hint", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -84,13 +103,14 @@ py_test( ], ) -py_binary( +py_library( name = "convert_saved_model", srcs = ["convert_saved_model.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":lite", + ":convert", + ":lite_constants", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", "//tensorflow/python/tools:freeze_graph_lib", @@ -130,6 +150,15 @@ py_test( ], ) +py_binary( + name = "convert_saved_model_to_frozen_graph", + srcs = ["convert_saved_model_to_frozen_graph.py"], + srcs_version = "PY2AND3", + deps = [ + ":convert_saved_model", + ], +) + # Transitive dependencies of this target will be included in the pip package. py_library( name = "tf_lite_py_pip", diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py new file mode 100644 index 0000000000..c4200c879b --- /dev/null +++ b/tensorflow/contrib/lite/python/convert.py @@ -0,0 +1,187 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converts a frozen graph into a TFLite FlatBuffer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os as _os +import subprocess as _subprocess +import tempfile as _tempfile + +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.python.platform import resource_loader as _resource_loader +from tensorflow.python.util.lazy_loader import LazyLoader + + +# Lazy load since some of the performance benchmark skylark rules +# break dependencies. +_toco_python = LazyLoader( + "tensorflow_wrap_toco", globals(), + "tensorflow.contrib.lite.toco.python." + "tensorflow_wrap_toco") +del LazyLoader + +# Find the toco_from_protos binary using the resource loader if using from +# bazel, otherwise we are in a pip where console_scripts already has +# the toco_from_protos tool. +if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY: + _toco_from_proto_bin = "" +else: + _toco_from_proto_bin = _resource_loader.get_path_to_datafile( + "../toco/python/toco_from_protos") + +if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin): + _toco_from_proto_bin = "toco_from_protos" + + +def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): + """Convert `input_data_str` according to model and toco parameters. + + Unless you know what you are doing consider using + the more friendly @{tf.contrib.lite.toco_convert}}. + + Args: + model_flags_str: Serialized proto describing model properties, see + `toco/model_flags.proto`. + toco_flags_str: Serialized proto describing conversion properties, see + `toco/toco_flags.proto`. + input_data_str: Input data in serialized form (e.g. a graphdef is common) + Returns: + Converted model in serialized form (e.g. a TFLITE model is common). + Raises: + RuntimeError: When conversion fails, an exception is raised with the error + message embedded. + """ + # TODO(aselle): When toco does not use fatal errors for failure, we can + # switch this on. + if not _toco_from_proto_bin: + return _toco_python.TocoConvert( + model_flags_str, toco_flags_str, input_data_str) + + with _tempfile.NamedTemporaryFile() as fp_toco, \ + _tempfile.NamedTemporaryFile() as fp_model, \ + _tempfile.NamedTemporaryFile() as fp_input, \ + _tempfile.NamedTemporaryFile() as fp_output: + fp_model.write(model_flags_str) + fp_toco.write(toco_flags_str) + fp_input.write(input_data_str) + fp_model.flush() + fp_toco.flush() + fp_input.flush() + + cmd = [ + _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name, + fp_output.name + ] + cmdline = " ".join(cmd) + proc = _subprocess.Popen( + cmdline, + shell=True, + stdout=_subprocess.PIPE, + stderr=_subprocess.STDOUT, + close_fds=True) + stdout, stderr = proc.communicate() + exitcode = proc.returncode + if exitcode == 0: + stuff = fp_output.read() + return stuff + else: + raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" % + (stdout, stderr)) + + +def tensor_name(x): + return x.name.split(":")[0] + + +def toco_convert(input_data, + input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + drop_control_dependency=True): + """Convert a model using TOCO from `input_format` to `output_format`. + + Typically this is to convert from TensorFlow GraphDef to TFLite, in which + case the default `input_format` and `output_format` are sufficient. + + Args: + input_data: Input data (i.e. often `sess.graph_def`). + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. + input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). + output_format: Type of data to write (currently must be TFLITE or + GRAPHVIZ_DOT) + quantized_input_stats: For each member of input_tensors the mean and + std deviation of training data. Only needed if `inference_type` is + `QUANTIZED_UINT8`. + drop_control_dependency: Drops control dependencies silently. This is due + to tf lite not supporting control dependencies. + + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + ValueError: If the input tensor type is unknown + RuntimeError: If TOCO fails to convert (in which case the runtime error's + error text will contain the TOCO error log) + """ + toco = _toco_flags_pb2.TocoFlags() + toco.input_format = input_format + toco.output_format = output_format + toco.drop_control_dependency = drop_control_dependency + model = _model_flags_pb2.ModelFlags() + toco.inference_type = inference_type + for idx, input_tensor in enumerate(input_tensors): + if input_tensor.dtype == _dtypes.float32: + tflite_input_type = lite_constants.FLOAT + elif input_tensor.dtype == _dtypes.int32: + tflite_input_type = lite_constants.INT32 + elif input_tensor.dtype == _dtypes.int64: + tflite_input_type = lite_constants.INT64 + # TODO(aselle): Insert strings when they are available + else: + raise ValueError("Tensors %s not known type %r" % (input_tensor.name, + input_tensor.dtype)) + + input_array = model.input_arrays.add() + + if inference_type == lite_constants.QUANTIZED_UINT8: + if tflite_input_type == lite_constants.FLOAT: + tflite_input_type = lite_constants.QUANTIZED_UINT8 + input_array.mean_value, input_array.std_value = quantized_input_stats[idx] + + input_array.name = tensor_name(input_tensor) + input_array.shape.dims.extend(map(int, input_tensor.get_shape())) + + for output_tensor in output_tensors: + model.output_arrays.append(tensor_name(output_tensor)) + + # TODO(aselle): Consider handling the case of allowing quantized + # inputs to be converted to float (via the toco.inference_input_type field). + data = toco_convert_protos(model.SerializeToString(), + toco.SerializeToString(), + input_data.SerializeToString()) + return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index a2b5ef488e..a7eddf3408 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -12,52 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -r"""TensorFlow Lite flatbuffer generation from saved_models. +"""Functions to convert SavedModel to frozen GraphDefs.""" -Example: - -bazel run third_party/tensorflow/contrib/lite/python:convert_saved_model -- \ - --saved_model_dir=/tmp/test_saved_model/1519865537 \ - --output_tflite=/tmp/test.lite - -""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import convert +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.toco import model_flags_pb2 from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import app -from tensorflow.python.platform import flags from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants -flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.") -flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.") -flags.DEFINE_string("output_arrays", None, - "List of output tensor names, the default value is None, " - "which means the conversion will keep all outputs.") -flags.DEFINE_integer("batch_size", 1, - "If input tensor shape has None at first dimension, " - "e.g. (None,224,224,3), replace None with batch_size.") -flags.DEFINE_string("tag_set", tag_constants.SERVING, - "Group of tag(s) of the MetaGraphDef in the saved_model, " - "in string format, separated by ','. For tag-set contains " - "multiple tags, all tags must be passed in.") -flags.DEFINE_string("signature_key", - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - "This is signature key to extract inputs, outputs.") - - -def log_tensor_details(tensor_info): + +def _write_and_flush_file(file_path, data_str): + """Writes data to file path. + + Args: + file_path: Full path of the file to store data in. + data_str: Data represented as a string. + + Returns: None. + """ + with gfile.Open(file_path, "wb") as data_file: + data_file.write(data_str) + data_file.flush() + + +def _log_tensor_details(tensor_info): """Log tensor details: name, shape, and type.""" for key in tensor_info: val = tensor_info[key] @@ -73,7 +64,7 @@ def log_tensor_details(tensor_info): dtype) -def get_meta_graph_def(saved_model_dir, tag_set): +def _get_meta_graph_def(saved_model_dir, tag_set): """Validate saved_model and extract MetaGraphDef. Args: @@ -103,7 +94,7 @@ def get_meta_graph_def(saved_model_dir, tag_set): "values are '{}'. ".format(tag_set, tag_sets)) -def get_signature_def(meta_graph, signature_key): +def _get_signature_def(meta_graph, signature_key): """Get the signature def from meta_graph with given signature_key. Args: @@ -130,11 +121,11 @@ def get_signature_def(meta_graph, signature_key): return signature_def -def get_inputs_outputs(signature_def): - """Get inputs and outputs from signature def. +def _get_inputs_outputs(signature_def): + """Get inputs and outputs from SignatureDef. Args: - signature_def: signatuer def in the meta_graph_def for conversion. + signature_def: SignatureDef in the meta_graph_def for conversion. Returns: The inputs and outputs in the graph for conversion. @@ -142,9 +133,9 @@ def get_inputs_outputs(signature_def): inputs_tensor_info = signature_def.inputs outputs_tensor_info = signature_def.outputs logging.info("input tensors info: ") - log_tensor_details(inputs_tensor_info) + _log_tensor_details(inputs_tensor_info) logging.info("output tensors info: ") - log_tensor_details(outputs_tensor_info) + _log_tensor_details(outputs_tensor_info) def gather_names(tensor_info): return [tensor_info[key].name for key in tensor_info] @@ -154,109 +145,277 @@ def get_inputs_outputs(signature_def): return inputs, outputs -def convert(saved_model_dir, - output_tflite=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1): - """Convert a saved_model to tflite flatbuffer. +def _get_tensors(graph, signature_def_tensor_names=None, + user_tensor_names=None): + """Gets the tensors associated with the tensor names. + + Either signature_def_tensor_names or user_tensor_names should be provided. If + the user provides tensors, the tensors associated with the user provided + tensor names are provided. Otherwise, the tensors associated with the names in + the SignatureDef are provided. Args: - saved_model_dir: Saved model directory to convert. - output_tflite: File path to write result flatbuffer. - output_arrays: List of output tensor names, the default value is None, which - means conversion keeps all output tensors. This is also used to filter - tensors that are from Op currently not supported in tflite, e.g., Argmax). - tag_set: This is the set of tags to get meta_graph_def in saved_model. - signature_key: This is the signature key to extract inputs, outputs. - batch_size: If input tensor shape has None at first dimension, - e.g. (None,224,224,3), replace None with batch_size. + graph: GraphDef representing graph. + signature_def_tensor_names: Tensor names stored in either the inputs or + outputs of a SignatureDef. (default None) + user_tensor_names: Tensor names provided by the user. (default None) Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. + List of tensors. + + Raises: + ValueError: + signature_def_tensors and user_tensor_names are undefined or empty. + user_tensor_names are not valid. + """ + tensors = [] + if user_tensor_names: + # Get the list of all of the tensors with and without the tensor index. + all_tensor_names = [ + tensor.name for op in graph.get_operations() for tensor in op.outputs + ] + all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] + + # Sort the tensor names. + user_tensor_names = sorted(user_tensor_names) + + # Get the tensors associated with the tensor names. + tensors = [] + invalid_tensors = [] + for name in user_tensor_names: + if name not in all_tensor_names_only: + invalid_tensors.append(name) + else: + idx = all_tensor_names_only.index(name) + tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + elif signature_def_tensor_names: + tensors = [ + graph.get_tensor_by_name(name) + for name in sorted(signature_def_tensor_names) + ] + else: + # Throw ValueError if signature_def_tensors and user_tensor_names are both + # either undefined or empty. + raise ValueError( + "Specify either signature_def_tensor_names or user_tensor_names") + + return tensors + + +def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key, batch_size): + """Converts a SavedModel to a frozen graph. + + Args: + saved_model_dir: SavedModel directory to convert. + input_arrays: List of input tensors to freeze graph with. Uses input arrays + from SignatureDef when none are provided. (default None) + input_shapes: Map of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default "serve") + signature_key: Key identifying SignatureDef containing inputs and outputs. + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + + Returns: + frozen_graph_def: Frozen GraphDef. + in_tensors: List of input tensors for the graph. + out_tensors: List of output tensors for the graph. Raises: - ValueError: If tag_set does not indicate any meta_graph_def in saved_model, - or signature_key is not in relevant meta_graph_def, - or input shape has None beyond 1st dimension, e.g., (1,None, None, 3), - or given output_arrays are not valid causing empty outputs. + ValueError: + SavedModel doesn't contain a MetaGraphDef identified by tag_set. + signature_key is not in the MetaGraphDef. + input_shapes does not match the length of input_arrays. + input_shapes has a None value after the 1st dimension. + input_arrays or output_arrays are not valid. + Unable to load Session. """ + # Set default values for inputs if they are set to None. + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY if tag_set is None: tag_set = set([tag_constants.SERVING]) + if batch_size is None: + batch_size = 1 - meta_graph = get_meta_graph_def(saved_model_dir, tag_set) - signature_def = get_signature_def(meta_graph, signature_key) - inputs, outputs = get_inputs_outputs(signature_def) + # Read SignatureDef. + meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) + signature_def = _get_signature_def(meta_graph, signature_key) + inputs, outputs = _get_inputs_outputs(signature_def) graph = ops.Graph() with session.Session(graph=graph) as sess: - + # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) - in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs] - - # Users can use output_arrays to filter output tensors for conversion. - # If output_arrays is None, we keep all output tensors. In future, we may - # use tflite supported Op list and check whether op is custom Op to - # automatically filter output arrays. - # TODO(zhixianyan): Use tflite supported Op list to filter outputs. - if output_arrays is not None: - output_arrays = output_arrays.split(",") - out_tensors = [ - graph.get_tensor_by_name(output) - for output in outputs - if output.split(":")[0] in output_arrays - ] - else: - out_tensors = [graph.get_tensor_by_name(output) for output in outputs] + # Gets input and output tensors. + # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. + in_tensors = _get_tensors(graph, inputs, input_arrays) + out_tensors = _get_tensors(graph, outputs, output_arrays) - output_names = [node.split(":")[0] for node in outputs] + # Gets fully defined tensor shape. An input tensor with None in the first + # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size. + # Shapes with None after the first dimension result in a ValueError. + # TODO(zhixianyan): Add supports for input tensor with more None in shape. + for tensor in in_tensors: + if (input_shapes and tensor.name in input_shapes and + input_shapes[tensor.name] is not None): + shape = input_shapes[tensor.name] + else: + shape = tensor.get_shape().as_list() - if not out_tensors: - raise ValueError( - "No valid output tensors for '{}', possible values are '{}'".format( - output_arrays, output_names)) + if None in shape[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(tensor.name, shape)) + elif shape[0] is None: + shape[0] = batch_size + tensor.set_shape(shape) + output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) - # Toco requires fully defined tensor shape, for input tensor with None in - # their shape, e.g., (None, 224, 224, 3), we need to replace first None with - # a given batch size. For shape with more None, e.g. (None, None, None, 3), - # still be able to replace and convert, but require further investigation. - # TODO(zhixianyan): Add supports for input tensor with more None in shape. - for i in range(len(in_tensors)): - shape = in_tensors[i].get_shape().as_list() - if shape[0] is None: - shape[0] = batch_size - if None in shape[1:]: - raise ValueError( - "Only support None shape at 1st dim as batch_size. But tensor " - "'{}' 's shape '{}' has None at other dimension. ".format( - inputs[i], shape)) - in_tensors[i].set_shape(shape) + return frozen_graph_def, in_tensors, out_tensors + raise ValueError("Unable to load Session.") - result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors) - if output_tflite is not None: - with gfile.Open(output_tflite, "wb") as f: - f.write(result) - logging.info("Successfully converted to: %s", output_tflite) +def saved_model_to_frozen_graphdef( + saved_model_dir, + output_file_model, + output_file_flags, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + batch_size=1): + """Converts a SavedModel to a frozen graph. Writes graph to tmp directory. - return result + Stores frozen graph and command line flags in the tmp directory. + Args: + saved_model_dir: SavedModel directory to convert. + output_file_model: Full file path to save frozen graph. + output_file_flags: Full file path to save ModelFlags. + input_arrays: List of input tensors to freeze graph with. Uses input arrays + from SignatureDef when none are provided. (default None) + input_shapes: Map of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default "serve") + signature_key: Key identifying SignatureDef containing inputs and outputs. + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + + Returns: None. -def main(_): - convert( - saved_model_dir=flags.FLAGS.saved_model_dir, - output_tflite=flags.FLAGS.output_tflite, - output_arrays=flags.FLAGS.output_arrays, - batch_size=flags.FLAGS.batch_size, - tag_set=set(flags.FLAGS.tag_set.split(",")), - signature_key=flags.FLAGS.signature_key) + Raises: + ValueError: Unable to convert to frozen graph. + """ + frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( + saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, + signature_key, batch_size) + + # Initialize model flags. + model = model_flags_pb2.ModelFlags() + + for input_tensor in in_tensors: + input_array = model.input_arrays.add() + input_array.name = convert.tensor_name(input_tensor) + input_array.shape.dims.extend(map(int, input_tensor.get_shape())) + + for output_tensor in out_tensors: + model.output_arrays.append(convert.tensor_name(output_tensor)) + + # Write model and ModelFlags to file. ModelFlags contain input array and + # output array information that is parsed from the SignatureDef and used for + # analysis by TOCO. + _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString()) + _write_and_flush_file(output_file_flags, model.SerializeToString()) + + +def tflite_from_saved_model( + saved_model_dir, + output_file=None, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + batch_size=1, + inference_type=lite_constants.FLOAT, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + drop_control_dependency=True): + """Converts a SavedModel to TFLite FlatBuffer. + Args: + saved_model_dir: SavedModel directory to convert. + output_file: File path to write result TFLite FlatBuffer. + input_arrays: List of input tensors to freeze graph with. Uses input arrays + from SignatureDef when none are provided. (default None) + input_shapes: Map of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default "serve") + signature_key: Key identifying SignatureDef containing inputs and outputs. + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. + input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). + output_format: Type of data to write (currently must be TFLITE or + GRAPHVIZ_DOT) + quantized_input_stats: For each member of input_tensors the mean and + std deviation of training data. Only needed if `inference_type` is + `QUANTIZED_UINT8`. + drop_control_dependency: Drops control dependencies silently. This is due + to tf lite not supporting control dependencies. -if __name__ == "__main__": - app.run(main) + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + ValueError: Unable to convert to frozen graph. + """ + frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( + saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, + signature_key, batch_size) + + result = convert.toco_convert( + input_data=frozen_graph_def, + input_tensors=in_tensors, + output_tensors=out_tensors, + inference_type=inference_type, + input_format=input_format, + output_format=output_format, + quantized_input_stats=quantized_input_stats, + drop_control_dependency=drop_control_dependency) + + if output_file is not None: + with gfile.Open(output_file, "wb") as f: + f.write(result) + logging.info("Successfully converted to: %s", output_file) + + return result diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 734e42d619..db95fc8ad7 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TF Lite SavedModel Conversion test cases. - - - test on generated saved_models from simple graphs (sanity check) - - test mnist savedmodel generated on-the-fly +"""TFLite SavedModel conversion test cases. + - Tests converting simple SavedModel graph to TFLite FlatBuffer. + - Tests converting simple SavedModel graph to frozen graph. + - Tests converting MNIST SavedModel to TFLite FlatBuffer. """ from __future__ import absolute_import @@ -25,6 +25,7 @@ from __future__ import print_function import os from tensorflow.contrib.lite.python import convert_saved_model +from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib as estimator @@ -37,6 +38,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model from tensorflow.python.training import training as train @@ -45,7 +47,7 @@ from tensorflow.python.training import training as train class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape): - """Create a simple savedmodel on the fly.""" + """Create a simple SavedModel on the fly.""" saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") with session.Session() as sess: in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32) @@ -56,44 +58,78 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): return saved_model_dir def testSimpleSavedModel(self): - """Test a simple savedmodel created on the fly.""" - # Create a simple savedmodel + """Test a simple SavedModel created on the fly.""" + # Create a simple SavedModel saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) # Convert to tflite - result = convert_saved_model.convert(saved_model_dir=saved_model_dir) + result = convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir) self.assertTrue(result) def testSimpleSavedModelWithNoneBatchSizeInShape(self): - """Test a simple savedmodel, with None in input tensor's shape.""" + """Test a simple SavedModel, with None in input tensor's shape.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.convert(saved_model_dir=saved_model_dir) + result = convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir) self.assertTrue(result) def testSimpleSavedModelWithMoreNoneInShape(self): - """Test a simple savedmodel, fail as more None in input shape.""" + """Test a simple SavedModel, fail as more None in input shape.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) # Convert to tflite: this should raise ValueError, as 3rd dim is None. with self.assertRaises(ValueError): - convert_saved_model.convert(saved_model_dir=saved_model_dir) + convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir) def testSimpleSavedModelWithWrongSignatureKey(self): - """Test a simple savedmodel, fail as given signature is invalid.""" + """Test a simple SavedModel, fail as given signature is invalid.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) # Convert to tflite: this should raise ValueError, as # signature_key does not exit in the saved_model. with self.assertRaises(ValueError): - convert_saved_model.convert( + convert_saved_model.tflite_from_saved_model( saved_model_dir=saved_model_dir, signature_key="wrong-key") def testSimpleSavedModelWithWrongOutputArray(self): - """Test a simple savedmodel, fail as given output_arrays is invalid.""" - # Create a simple savedmodel + """Test a simple SavedModel, fail as given output_arrays is invalid.""" + # Create a simple SavedModel saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) # Convert to tflite: this should raise ValueError, as # output_arrays is not valid for the saved_model. with self.assertRaises(ValueError): - convert_saved_model.convert( - saved_model_dir=saved_model_dir, output_arrays="wrong-output") + convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir, output_arrays=["wrong-output"]) + + def testSimpleSavedModelWithWrongInputArrays(self): + """Test a simple SavedModel, fail as given input_arrays is invalid.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Checks invalid input_arrays. + with self.assertRaises(ValueError): + convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir, input_arrays=["wrong-input"]) + # Checks valid and invalid input_arrays. + with self.assertRaises(ValueError): + convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=["Placeholder", "wrong-input"]) + + def testSimpleSavedModelWithCorrectArrays(self): + """Test a simple SavedModel, with correct input_arrays and output_arrays.""" + saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) + result = convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=["Placeholder"], + output_arrays=["add"]) + self.assertTrue(result) + + def testSimpleSavedModelWithCorrectInputArrays(self): + """Test a simple SavedModel, with correct input_arrays and input_shapes.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + result = convert_saved_model.tflite_from_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=["Placeholder"], + input_shapes={"Placeholder": [1, 16, 16, 3]}) + self.assertTrue(result) def testMultipleMetaGraphDef(self): """Test saved model with multiple MetaGraphDef.""" @@ -119,20 +155,103 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): sess, tags=[saved_model.tag_constants.SERVING, "additional_test_tag"], signature_def_map=signature_def_map) + # MetaGraphDef 2 builder.add_meta_graph(tags=["tflite"]) builder.save(True) # Convert to tflite - convert_saved_model.convert( + convert_saved_model.tflite_from_saved_model( saved_model_dir=saved_model_dir, tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) +class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase): + + def _createSimpleSavedModel(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputB") + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputA") + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {"x": in_tensor_1, "y": in_tensor_2} + outputs = {"z": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def _getInputArrayNames(self, model_proto): + return [data.name for data in model_proto.input_arrays] + + def _getInputArrayShapes(self, model_proto): + return [ + [dim for dim in data.shape.dims] for data in model_proto.input_arrays + ] + + def _get_model_flags_proto_from_file(self, filename): + proto = _model_flags_pb2.ModelFlags() + with gfile.Open(filename, "rb") as output_file: + proto.ParseFromString(output_file.read()) + output_file.close() + return proto + + def testSimpleSavedModel(self): + """Test a simple SavedModel.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + output_file_model = os.path.join(self.get_temp_dir(), "model.pb") + output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") + + convert_saved_model.saved_model_to_frozen_graphdef( + saved_model_dir=saved_model_dir, + output_file_model=output_file_model, + output_file_flags=output_file_flags, + input_arrays=["inputB", "inputA"]) + + proto = self._get_model_flags_proto_from_file(output_file_flags) + self.assertEqual(proto.output_arrays, ["add"]) + self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"]) + self.assertEqual( + self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]]) + + def testSimpleSavedModelWithDifferentInputNames(self): + """Test a simple SavedModel.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + output_file_model = os.path.join(self.get_temp_dir(), "model.pb") + output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") + + # Check case where input shape is given. + convert_saved_model.saved_model_to_frozen_graphdef( + saved_model_dir=saved_model_dir, + output_file_model=output_file_model, + output_file_flags=output_file_flags, + input_arrays=["inputA"], + input_shapes={"inputA": [1, 16, 16, 3]}) + + proto = self._get_model_flags_proto_from_file(output_file_flags) + self.assertEqual(proto.output_arrays, ["add"]) + self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) + self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) + + # Check case where input shape is None. + convert_saved_model.saved_model_to_frozen_graphdef( + saved_model_dir=saved_model_dir, + output_file_model=output_file_model, + output_file_flags=output_file_flags, + input_arrays=["inputA"], + input_shapes={"inputA": None}) + + proto = self._get_model_flags_proto_from_file(output_file_flags) + self.assertEqual(proto.output_arrays, ["add"]) + self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) + self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) + + class Model(keras.Model): """Model to recognize digits in the MNIST dataset. - Train and export savedmodel, used for testOnflyTrainMnistSavedModel + Train and export SavedModel, used for testOnflyTrainMnistSavedModel Network structure is equivalent to: https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -238,7 +357,7 @@ def dummy_input_fn(): class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): def testTrainedMnistSavedModel(self): - """Test mnist savedmodel, trained with dummy data and small steps.""" + """Test mnist SavedModel, trained with dummy data and small steps.""" # Build classifier classifier = estimator.Estimator( model_fn=model_fn, @@ -253,21 +372,20 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): "image": image, }) - # Export savedmodel + # Export SavedModel saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel") classifier.export_savedmodel(saved_model_dir, pred_input_fn) # Convert to tflite and test output saved_model_name = os.listdir(saved_model_dir)[0] saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) - output_tflite = os.path.join(saved_model_dir, - saved_model_final_dir + ".lite") + output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite") # TODO(zhixianyan): no need to limit output_arrays to `Softmax' # once b/74205001 fixed and argmax implemented in tflite. - result = convert_saved_model.convert( + result = convert_saved_model.tflite_from_saved_model( saved_model_dir=saved_model_final_dir, - output_arrays="Softmax", - output_tflite=output_tflite) + output_arrays=["Softmax"], + output_file=output_file) self.assertTrue(result) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py new file mode 100644 index 0000000000..4d9782f4a6 --- /dev/null +++ b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py @@ -0,0 +1,106 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python console command for generating frozen models from SavedModels. + +This exists to add SavedModel compatibility to TOCO. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef +from tensorflow.python.platform import app + +FLAGS = None + + +def execute(unused_args): + """Calls function to convert the SavedModel to a frozen graph.""" + # Error handling. + if FLAGS.input_shapes and not FLAGS.input_arrays: + raise ValueError("Input shapes requires input arrays to be specified.") + + # Calls saved_model_to_frozen_graphdef function to generate frozen graph. + input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None) + input_shapes = None + if FLAGS.input_shapes: + input_shapes = { + input_arrays[idx]: shape.split(",") + for idx, shape in enumerate(FLAGS.input_shapes.split(":")) + } + output_arrays = ( + FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None) + tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None + + saved_model_to_frozen_graphdef( + saved_model_dir=FLAGS.saved_model_directory, + output_file_model=FLAGS.output_file_model, + output_file_flags=FLAGS.output_file_flags, + input_arrays=input_arrays, + input_shapes=input_shapes, + output_arrays=output_arrays, + tag_set=tag_set, + signature_key=FLAGS.signature_key, + batch_size=FLAGS.batch_size) + + +def main(): + global FLAGS + # Parses flags. + parser = argparse.ArgumentParser( + description="Invoke SavedModel to frozen model converter.") + parser.add_argument( + "saved_model_directory", + type=str, + help="Full path to directory containing the SavedModel.") + parser.add_argument( + "output_file_model", + type=str, + help="Full file path to save frozen graph.") + parser.add_argument( + "output_file_flags", type=str, help="Full file path to save ModelFlags.") + parser.add_argument( + "--input_arrays", + type=str, + help="Name of the input arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Name of the output arrays, comma-separated.") + parser.add_argument( + "--tag_set", type=str, help="Name of output arrays, comma-separated.") + parser.add_argument( + "--signature_key", + type=str, + help="Key identifying SignatureDef containing inputs and outputs.") + parser.add_argument( + "--batch_size", + type=int, + help="Batch size for the model. Replaces the first dimension of an " + "input size array if undefined.") + + FLAGS, unparsed = parser.parse_known_args() + + app.run(main=execute, argv=[sys.argv[0]] + unparsed) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/convert_test.py index b8b4510188..dc21a9b669 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/convert_test.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base +from tensorflow.contrib.lite.python import convert +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.python import op_hint from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class LiteTest(test_util.TensorFlowTestCase): +class ConvertTest(test_util.TensorFlowTestCase): def testBasic(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], @@ -37,13 +38,13 @@ class LiteTest(test_util.TensorFlowTestCase): out_tensor = in_tensor + in_tensor sess = session.Session() # Try running on valid graph - result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) + result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) self.assertTrue(result) # TODO(aselle): remove tests that fail (we must get TOCO to not fatal # all the time). # Try running on identity graph (known fail) # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): - # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) + # result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) def testQuantization(self): in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], @@ -51,13 +52,14 @@ class LiteTest(test_util.TensorFlowTestCase): out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor, min=0., max=1.) sess = session.Session() - result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor], - inference_type=lite.QUANTIZED_UINT8, - quantized_input_stats=[(0., 1.)]) + result = convert.toco_convert( + sess.graph_def, [in_tensor], [out_tensor], + inference_type=lite_constants.QUANTIZED_UINT8, + quantized_input_stats=[(0., 1.)]) self.assertTrue(result) -class LiteTestOpHint(test_util.TensorFlowTestCase): +class ConvertTestOpHint(test_util.TensorFlowTestCase): """Test the hint to stub functionality.""" def _getGraphOpTypes(self, graphdef, output_nodes): @@ -99,7 +101,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): swish_scale = array_ops.constant(1.0) def _swish(input_tensor, scale): - custom = lite.OpHint("cool_activation") + custom = op_hint.OpHint("cool_activation") input_tensor, scale = custom.add_inputs(input_tensor, scale) output = math_ops.sigmoid(input_tensor) * input_tensor * scale output, = custom.add_outputs(output) @@ -111,11 +113,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): # and 1 final output). self.assertEqual(self._countIdentities(sess.graph_def.node), 4) - stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess) self.assertCountEqual( self._getGraphOpTypes( - stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + stubbed_graphdef, + output_nodes=[op_hint._tensor_name_base(output)]), ["cool_activation", "Const", "Identity"]) def testScaleAndBiasAndIdentity(self): @@ -125,7 +128,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): b = array_ops.constant([4., 5.]) def _scaled_and_bias_and_identity(a, x, b): - custom = lite.OpHint("scale_and_bias_and_identity") + custom = op_hint.OpHint("scale_and_bias_and_identity") a, x, b = custom.add_inputs(a, x, b) return custom.add_outputs(a * x + b, x) output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b), @@ -136,11 +139,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): # +1 for the final output self.assertEqual(self._countIdentities(sess.graph_def.node), 6) - stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess) self.assertCountEqual( self._getGraphOpTypes( - stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + stubbed_graphdef, + output_nodes=[op_hint._tensor_name_base(output)]), ["scale_and_bias_and_identity", "Const", "Identity", "Pack"]) def testTwoFunctions(self): @@ -148,7 +152,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): a = array_ops.constant([1.]) b = array_ops.constant([1.]) def _double_values(x): - custom = lite.OpHint("add_test") + custom = op_hint.OpHint("add_test") x = custom.add_inputs(x) output = math_ops.multiply(x, x) output, = custom.add_outputs(output) @@ -160,10 +164,11 @@ class LiteTestOpHint(test_util.TensorFlowTestCase): # make sure one identity for each input (2) and output (2) => 2 + 2 # +1 for the final output self.assertEqual(self._countIdentities(sess.graph_def.node), 5) - stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess) self.assertCountEqual( self._getGraphOpTypes( - stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + stubbed_graphdef, + output_nodes=[op_hint._tensor_name_base(output)]), ["add_test", "Const", "Identity", "Add"]) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index cf50f9d4d6..4ea40201f7 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -18,6 +18,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@toco_convert @@toco_convert_protos +@@tflite_from_saved_model @@OpHint @@convert_op_hints_to_stubs @@ -25,208 +26,11 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os as _os -import subprocess as _subprocess -import tempfile as _tempfile # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert +from tensorflow.contrib.lite.python.convert import toco_convert_protos +from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: enable=unused-import -from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 -from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 -from tensorflow.python.framework import dtypes as _dtypes -from tensorflow.python.platform import resource_loader as _resource_loader -from tensorflow.python.util.all_util import remove_undocumented -from tensorflow.python.util.lazy_loader import LazyLoader - -# Lazy load since some of the performance benchmark skylark rules -# break dependencies. -_toco_python = LazyLoader( - "tensorflow_wrap_toco", globals(), - "tensorflow.contrib.lite.toco.python." - "tensorflow_wrap_toco") -del LazyLoader - -# Enum types from the protobuf promoted to the API -FLOAT = _types_pb2.FLOAT -INT32 = _types_pb2.INT32 -INT64 = _types_pb2.INT64 -STRING = _types_pb2.STRING -QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8 -TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF -TFLITE = _toco_flags_pb2.TFLITE -GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT - -# Currently the default mode of operation is to shell to another python process -# to protect against crashes. However, it breaks some dependent targets because -# it forces us to depend on an external py_binary. The experimental API doesn't -# have that drawback. -EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False - -# Find the toco_from_protos binary using the resource loader if using from -# bazel, otherwise we are in a pip where console_scripts already has -# the toco_from_protos tool. -if EXPERIMENTAL_USE_TOCO_API_DIRECTLY: - _toco_from_proto_bin = "" -else: - _toco_from_proto_bin = _resource_loader.get_path_to_datafile( - "../toco/python/toco_from_protos") - -if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin): - _toco_from_proto_bin = "toco_from_protos" - - -def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): - """Convert `input_data_str` according to model and toco parameters. - - Unless you know what you are doing consider using - the more friendly @{tf.contrib.lite.toco_convert}}. - - Args: - model_flags_str: Serialized proto describing model properties, see - `toco/model_flags.proto`. - toco_flags_str: Serialized proto describing conversion properties, see - `toco/toco_flags.proto`. - input_data_str: Input data in serialized form (e.g. a graphdef is common) - Returns: - Converted model in serialized form (e.g. a TFLITE model is common). - Raises: - RuntimeError: When conversion fails, an exception is raised with the error - message embedded. - """ - # TODO(aselle): When toco does not use fatal errors for failure, we can - # switch this on. - if not _toco_from_proto_bin: - return _toco_python.TocoConvert( - model_flags_str, toco_flags_str, input_data_str) - - with _tempfile.NamedTemporaryFile() as fp_toco, \ - _tempfile.NamedTemporaryFile() as fp_model, \ - _tempfile.NamedTemporaryFile() as fp_input, \ - _tempfile.NamedTemporaryFile() as fp_output: - fp_model.write(model_flags_str) - fp_toco.write(toco_flags_str) - fp_input.write(input_data_str) - fp_model.flush() - fp_toco.flush() - fp_input.flush() - - cmd = [ - _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name, - fp_output.name - ] - cmdline = " ".join(cmd) - proc = _subprocess.Popen( - cmdline, - shell=True, - stdout=_subprocess.PIPE, - stderr=_subprocess.STDOUT, - close_fds=True) - stdout, stderr = proc.communicate() - exitcode = proc.returncode - if exitcode == 0: - stuff = fp_output.read() - return stuff - else: - raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" % - (stdout, stderr)) - - -def _tensor_name(x): - return x.name.split(":")[0] - - -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=FLOAT, - input_format=TENSORFLOW_GRAPHDEF, - output_format=TFLITE, - quantized_input_stats=None, - drop_control_dependency=True, - allow_custom_ops=None): - """Convert a model using TOCO from `input_format` to `output_format`. - - Typically this is to convert from TensorFlow GraphDef to TFLite, in which - case the default `input_format` and `output_format` are sufficient. - - Args: - input_data: Input data (i.e. often `sess.graph_def`). - input_tensors: List of input tensors. Type and shape are computed using - `foo.get_shape()` and `foo.dtype`. - output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. - - Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. - - Raises: - ValueError: If the input tensor type is unknown - RuntimeError: If TOCO fails to convert (in which case the runtime error's - error text will contain the TOCO error log) - """ - toco = _toco_flags_pb2.TocoFlags() - toco.input_format = input_format - toco.output_format = output_format - toco.inference_type = inference_type - toco.drop_control_dependency = drop_control_dependency - if allow_custom_ops is not None: - toco.allow_custom_ops = allow_custom_ops - - model = _model_flags_pb2.ModelFlags() - for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = INT64 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - - input_array = model.input_arrays.add() - - if inference_type == QUANTIZED_UINT8: - if tflite_input_type == FLOAT: - tflite_input_type = QUANTIZED_UINT8 - input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - - input_array.name = _tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - - for output_tensor in output_tensors: - model.output_arrays.append(_tensor_name(output_tensor)) - - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), - input_data.SerializeToString()) - return data - - -_allowed_symbols = [ - "FLOAT", - "INT32", - "INT64", - "STRING", - "QUANTIZED_UINT8", - "TENSORFLOW_GRAPHDEF", - "TFLITE", - "GRAPHVIZ_DOT", - "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", -] -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/lite/python/lite_constants.py b/tensorflow/contrib/lite/python/lite_constants.py new file mode 100644 index 0000000000..195d7a732f --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_constants.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants for TFLite.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.util.all_util import remove_undocumented + +# Enum types from the protobuf promoted to the API +FLOAT = _types_pb2.FLOAT +INT32 = _types_pb2.INT32 +INT64 = _types_pb2.INT64 +STRING = _types_pb2.STRING +QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8 +TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF +TFLITE = _toco_flags_pb2.TFLITE +GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT + +# Currently the default mode of operation is to shell to another python process +# to protect against crashes. However, it breaks some dependent targets because +# it forces us to depend on an external py_binary. The experimental API doesn't +# have that drawback. +EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False + + +_allowed_symbols = [ + "FLOAT", + "INT32", + "INT64", + "STRING", + "QUANTIZED_UINT8", + "TENSORFLOW_GRAPHDEF", + "TFLITE", + "GRAPHVIZ_DOT", + "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", +] +remove_undocumented(__name__, _allowed_symbols) |