aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/convert_saved_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/convert_saved_model.py')
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py31
1 files changed, 10 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 5dad49f1ed..1553464b9f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.lite.python.convert import tensor_name
-from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
@@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
Raises:
ValueError: No valid MetaGraphDef for given tag_set.
"""
- saved_model = reader.read_saved_model(saved_model_dir)
- tag_sets = []
- result_meta_graph_def = None
- for meta_graph_def in saved_model.meta_graphs:
- meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags)
- tag_sets.append(meta_graph_tag_set)
- if meta_graph_tag_set == tag_set:
- result_meta_graph_def = meta_graph_def
- logging.info("The given saved_model contains the following tags: %s",
- tag_sets)
- if result_meta_graph_def is not None:
- return result_meta_graph_def
- else:
- raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
- "values are '{}'. ".format(tag_set, tag_sets))
+ with session.Session(graph=ops.Graph()) as sess:
+ return loader.load(sess, tag_set, saved_model_dir)
def _get_signature_def(meta_graph, signature_key):
@@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key):
raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
"values are '{}'.".format(signature_key,
",".join(signature_def_keys)))
- signature_def = signature_def_utils.get_signature_def_by_key(
- meta_graph, signature_key)
- return signature_def
+ return signature_def_map[signature_key]
def _get_inputs_outputs(signature_def):
@@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
ValueError:
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
+ assets/ directory is in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
"""
@@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_def = _get_signature_def(meta_graph, signature_key)
inputs, outputs = _get_inputs_outputs(signature_def)
+ # Check SavedModel for assets directory.
+ collection_def = meta_graph.collection_def
+ if constants.ASSETS_KEY in collection_def:
+ raise ValueError("SavedModels with assets/ directory are not supported.")
+
graph = ops.Graph()
with session.Session(graph=graph) as sess:
- # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
# Gets input and output tensors.