aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/convert_saved_model.py
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-05-24 10:53:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 10:57:39 -0700
commitd9b764d72aa8e1f7959c396762d2054ee9d87cab (patch)
tree0ccd4e152d78c86f276dfe19f243b1d7a9a618de /tensorflow/contrib/lite/python/convert_saved_model.py
parentf286fb4557ab48f38882bc643ccc9a2c85677c63 (diff)
Improve TOCO Python API.
PiperOrigin-RevId: 197918102
Diffstat (limited to 'tensorflow/contrib/lite/python/convert_saved_model.py')
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py162
1 files changed, 7 insertions, 155 deletions
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index a7eddf3408..54fec9d61f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import convert
-from tensorflow.contrib.lite.python import lite_constants
-from tensorflow.contrib.lite.toco import model_flags_pb2
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
@@ -110,12 +107,12 @@ def _get_signature_def(meta_graph, signature_key):
signature_def_map = meta_graph.signature_def
signature_def_keys = set(signature_def_map.keys())
logging.info(
- "The given saved_model MetaGraphDef contains SignatureDefs with the "
+ "The given SavedModel MetaGraphDef contains SignatureDefs with the "
"following keys: %s", signature_def_keys)
if signature_key not in signature_def_keys:
- raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible "
- "values are '{}'. ".format(signature_key,
- signature_def_keys))
+ raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
+ "values are '{}'.".format(signature_key,
+ ",".join(signature_def_keys)))
signature_def = signature_def_utils.get_signature_def_by_key(
meta_graph, signature_key)
return signature_def
@@ -207,8 +204,8 @@ def _get_tensors(graph, signature_def_tensor_names=None,
return tensors
-def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
- output_arrays, tag_set, signature_key, batch_size):
+def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
Args:
@@ -224,8 +221,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default "serve")
signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
Returns:
frozen_graph_def: Frozen GraphDef.
@@ -237,7 +232,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
- input_shapes has a None value after the 1st dimension.
input_arrays or output_arrays are not valid.
Unable to load Session.
"""
@@ -246,8 +240,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
if tag_set is None:
tag_set = set([tag_constants.SERVING])
- if batch_size is None:
- batch_size = 1
# Read SignatureDef.
meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
@@ -264,23 +256,13 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
- # Gets fully defined tensor shape. An input tensor with None in the first
- # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size.
- # Shapes with None after the first dimension result in a ValueError.
- # TODO(zhixianyan): Add supports for input tensor with more None in shape.
+ # Gets fully defined tensor shape.
for tensor in in_tensors:
if (input_shapes and tensor.name in input_shapes and
input_shapes[tensor.name] is not None):
shape = input_shapes[tensor.name]
else:
shape = tensor.get_shape().as_list()
-
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(tensor.name, shape))
- elif shape[0] is None:
- shape[0] = batch_size
tensor.set_shape(shape)
output_names = [node.split(":")[0] for node in outputs]
@@ -289,133 +271,3 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
return frozen_graph_def, in_tensors, out_tensors
raise ValueError("Unable to load Session.")
-
-
-def saved_model_to_frozen_graphdef(
- saved_model_dir,
- output_file_model,
- output_file_flags,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1):
- """Converts a SavedModel to a frozen graph. Writes graph to tmp directory.
-
- Stores frozen graph and command line flags in the tmp directory.
-
- Args:
- saved_model_dir: SavedModel directory to convert.
- output_file_model: Full file path to save frozen graph.
- output_file_flags: Full file path to save ModelFlags.
- input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
- Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
- output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
- tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
- signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
-
- Returns: None.
-
- Raises:
- ValueError: Unable to convert to frozen graph.
- """
- frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
- saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
- signature_key, batch_size)
-
- # Initialize model flags.
- model = model_flags_pb2.ModelFlags()
-
- for input_tensor in in_tensors:
- input_array = model.input_arrays.add()
- input_array.name = convert.tensor_name(input_tensor)
- input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
-
- for output_tensor in out_tensors:
- model.output_arrays.append(convert.tensor_name(output_tensor))
-
- # Write model and ModelFlags to file. ModelFlags contain input array and
- # output array information that is parsed from the SignatureDef and used for
- # analysis by TOCO.
- _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString())
- _write_and_flush_file(output_file_flags, model.SerializeToString())
-
-
-def tflite_from_saved_model(
- saved_model_dir,
- output_file=None,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1,
- inference_type=lite_constants.FLOAT,
- input_format=lite_constants.TENSORFLOW_GRAPHDEF,
- output_format=lite_constants.TFLITE,
- quantized_input_stats=None,
- drop_control_dependency=True):
- """Converts a SavedModel to TFLite FlatBuffer.
-
- Args:
- saved_model_dir: SavedModel directory to convert.
- output_file: File path to write result TFLite FlatBuffer.
- input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
- Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
- output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
- tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
- signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
-
- Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
-
- Raises:
- ValueError: Unable to convert to frozen graph.
- """
- frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
- saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
- signature_key, batch_size)
-
- result = convert.toco_convert(
- input_data=frozen_graph_def,
- input_tensors=in_tensors,
- output_tensors=out_tensors,
- inference_type=inference_type,
- input_format=input_format,
- output_format=output_format,
- quantized_input_stats=quantized_input_stats,
- drop_control_dependency=drop_control_dependency)
-
- if output_file is not None:
- with gfile.Open(output_file, "wb") as f:
- f.write(result)
- logging.info("Successfully converted to: %s", output_file)
-
- return result