aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-06-07 10:38:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 10:41:48 -0700
commitcd25a9544915654022e2cfff4923c31822166112 (patch)
tree3d85e5e728df88547bf997c47dd1e1224fa6e02e /tensorflow/contrib/lite/python
parent796fff865013f964e85c134dddf6f1f49574bd72 (diff)
Updated SavedModels in Python TOCO API.
PiperOrigin-RevId: 199658431
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py31
-rw-r--r--tensorflow/contrib/lite/python/lite.py2
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py2
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py2
5 files changed, 15 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 7e6ff6c0a8..27909a9458 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -57,8 +57,9 @@ py_library(
":interpreter",
":lite_constants",
":op_hint",
- "//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model:loader",
"//tensorflow/python/tools:freeze_graph_lib",
],
)
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 5dad49f1ed..1553464b9f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.lite.python.convert import tensor_name
-from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
@@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
Raises:
ValueError: No valid MetaGraphDef for given tag_set.
"""
- saved_model = reader.read_saved_model(saved_model_dir)
- tag_sets = []
- result_meta_graph_def = None
- for meta_graph_def in saved_model.meta_graphs:
- meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags)
- tag_sets.append(meta_graph_tag_set)
- if meta_graph_tag_set == tag_set:
- result_meta_graph_def = meta_graph_def
- logging.info("The given saved_model contains the following tags: %s",
- tag_sets)
- if result_meta_graph_def is not None:
- return result_meta_graph_def
- else:
- raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
- "values are '{}'. ".format(tag_set, tag_sets))
+ with session.Session(graph=ops.Graph()) as sess:
+ return loader.load(sess, tag_set, saved_model_dir)
def _get_signature_def(meta_graph, signature_key):
@@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key):
raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
"values are '{}'.".format(signature_key,
",".join(signature_def_keys)))
- signature_def = signature_def_utils.get_signature_def_by_key(
- meta_graph, signature_key)
- return signature_def
+ return signature_def_map[signature_key]
def _get_inputs_outputs(signature_def):
@@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
ValueError:
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
+ assets/ directory is in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
"""
@@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_def = _get_signature_def(meta_graph, signature_key)
inputs, outputs = _get_inputs_outputs(signature_def)
+ # Check SavedModel for assets directory.
+ collection_def = meta_graph.collection_def
+ if constants.ASSETS_KEY in collection_def:
+ raise ValueError("SavedModels with assets/ directory are not supported.")
+
graph = ops.Graph()
with session.Session(graph=graph) as sess:
- # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
# Gets input and output tensors.
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 253e3f72b1..e3a2d19e05 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -207,7 +207,7 @@ class TocoConverter(object):
# Check if graph is frozen.
if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py")
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
# Create TocoConverter class.
return cls(sess.graph_def, input_tensors, output_tensors)
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index bbb00021f9..b04caaf263 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -401,7 +401,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
- self.assertEqual('Please freeze the graph using freeze_graph.py',
+ self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
def testPbtxt(self):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 2b7ad29a27..4c215b62b2 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -114,7 +114,7 @@ def _convert_model(flags):
"--input_arrays must be present when specifying "
"--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and "
- "values".format(",".join(input_arrays)))
+ "values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
if flags.default_ranges_min and flags.default_ranges_max:
converter.default_ranges_stats = (flags.default_ranges_min,