diff options
author | 2018-05-24 10:53:28 -0700 | |
---|---|---|
committer | 2018-05-24 10:57:39 -0700 | |
commit | d9b764d72aa8e1f7959c396762d2054ee9d87cab (patch) | |
tree | 0ccd4e152d78c86f276dfe19f243b1d7a9a618de /tensorflow/contrib/lite/python/convert_saved_model.py | |
parent | f286fb4557ab48f38882bc643ccc9a2c85677c63 (diff) |
Improve TOCO Python API.
PiperOrigin-RevId: 197918102
Diffstat (limited to 'tensorflow/contrib/lite/python/convert_saved_model.py')
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model.py | 162 |
1 files changed, 7 insertions, 155 deletions
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index a7eddf3408..54fec9d61f 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import convert -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import model_flags_pb2 from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 @@ -110,12 +107,12 @@ def _get_signature_def(meta_graph, signature_key): signature_def_map = meta_graph.signature_def signature_def_keys = set(signature_def_map.keys()) logging.info( - "The given saved_model MetaGraphDef contains SignatureDefs with the " + "The given SavedModel MetaGraphDef contains SignatureDefs with the " "following keys: %s", signature_def_keys) if signature_key not in signature_def_keys: - raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " - "values are '{}'. ".format(signature_key, - signature_def_keys)) + raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " + "values are '{}'.".format(signature_key, + ",".join(signature_def_keys))) signature_def = signature_def_utils.get_signature_def_by_key( meta_graph, signature_key) return signature_def @@ -207,8 +204,8 @@ def _get_tensors(graph, signature_def_tensor_names=None, return tensors -def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key, batch_size): +def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph. Args: @@ -224,8 +221,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to analyze. All tags in the tag set must be present. (default "serve") signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) Returns: frozen_graph_def: Frozen GraphDef. @@ -237,7 +232,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. input_shapes does not match the length of input_arrays. - input_shapes has a None value after the 1st dimension. input_arrays or output_arrays are not valid. Unable to load Session. """ @@ -246,8 +240,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY if tag_set is None: tag_set = set([tag_constants.SERVING]) - if batch_size is None: - batch_size = 1 # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) @@ -264,23 +256,13 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - # Gets fully defined tensor shape. An input tensor with None in the first - # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size. - # Shapes with None after the first dimension result in a ValueError. - # TODO(zhixianyan): Add supports for input tensor with more None in shape. + # Gets fully defined tensor shape. for tensor in in_tensors: if (input_shapes and tensor.name in input_shapes and input_shapes[tensor.name] is not None): shape = input_shapes[tensor.name] else: shape = tensor.get_shape().as_list() - - if None in shape[1:]: - raise ValueError( - "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor.name, shape)) - elif shape[0] is None: - shape[0] = batch_size tensor.set_shape(shape) output_names = [node.split(":")[0] for node in outputs] @@ -289,133 +271,3 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, return frozen_graph_def, in_tensors, out_tensors raise ValueError("Unable to load Session.") - - -def saved_model_to_frozen_graphdef( - saved_model_dir, - output_file_model, - output_file_flags, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1): - """Converts a SavedModel to a frozen graph. Writes graph to tmp directory. - - Stores frozen graph and command line flags in the tmp directory. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file_model: Full file path to save frozen graph. - output_file_flags: Full file path to save ModelFlags. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - - Returns: None. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - # Initialize model flags. - model = model_flags_pb2.ModelFlags() - - for input_tensor in in_tensors: - input_array = model.input_arrays.add() - input_array.name = convert.tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - - for output_tensor in out_tensors: - model.output_arrays.append(convert.tensor_name(output_tensor)) - - # Write model and ModelFlags to file. ModelFlags contain input array and - # output array information that is parsed from the SignatureDef and used for - # analysis by TOCO. - _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString()) - _write_and_flush_file(output_file_flags, model.SerializeToString()) - - -def tflite_from_saved_model( - saved_model_dir, - output_file=None, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Converts a SavedModel to TFLite FlatBuffer. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file: File path to write result TFLite FlatBuffer. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. - - Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - result = convert.toco_convert( - input_data=frozen_graph_def, - input_tensors=in_tensors, - output_tensors=out_tensors, - inference_type=inference_type, - input_format=input_format, - output_format=output_format, - quantized_input_stats=quantized_input_stats, - drop_control_dependency=drop_control_dependency) - - if output_file is not None: - with gfile.Open(output_file, "wb") as f: - f.write(result) - logging.info("Successfully converted to: %s", output_file) - - return result |