diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-06-07 10:38:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-07 10:41:48 -0700 |
commit | cd25a9544915654022e2cfff4923c31822166112 (patch) | |
tree | 3d85e5e728df88547bf997c47dd1e1224fa6e02e /tensorflow/contrib/lite/python | |
parent | 796fff865013f964e85c134dddf6f1f49574bd72 (diff) |
Updated SavedModels in Python TOCO API.
PiperOrigin-RevId: 199658431
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/convert_saved_model.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 2 |
5 files changed, 15 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 7e6ff6c0a8..27909a9458 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -57,8 +57,9 @@ py_library( ":interpreter", ":lite_constants", ":op_hint", - "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", "//tensorflow/python/tools:freeze_graph_lib", ], ) diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index 5dad49f1ed..1553464b9f 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -19,13 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.lite.python.convert import tensor_name -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader @@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. """ - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. ".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key): raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " "values are '{}'.".format(signature_key, ",".join(signature_def_keys))) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. input_arrays or output_arrays are not valid. """ @@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 253e3f72b1..e3a2d19e05 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -207,7 +207,7 @@ class TocoConverter(object): # Check if graph is frozen. if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py") + raise ValueError("Please freeze the graph using freeze_graph.py.") # Create TocoConverter class. return cls(sess.graph_def, input_tensors, output_tensors) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index bbb00021f9..b04caaf263 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -401,7 +401,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as error: lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], ['add']) - self.assertEqual('Please freeze the graph using freeze_graph.py', + self.assertEqual('Please freeze the graph using freeze_graph.py.', str(error.exception)) def testPbtxt(self): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 2b7ad29a27..4c215b62b2 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -114,7 +114,7 @@ def _convert_model(flags): "--input_arrays must be present when specifying " "--std_dev_values and --mean_values with multiple input " "tensors in order to map between names and " - "values".format(",".join(input_arrays))) + "values.".format(",".join(input_arrays))) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) if flags.default_ranges_min and flags.default_ranges_max: converter.default_ranges_stats = (flags.default_ranges_min, |