aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-01-31 20:51:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 21:12:58 -0800
commit57f609d07fc9d11fef9cf415d8ebe902bee5d89f (patch)
tree725ad377e4f4554b4ca9997859b659bad6966a29
parent744e4070c503dccd07194e389c920b1f8167ca0d (diff)
Allow passing arguments through saved_model loader to saver.import_meta_graph.
These **kwargs are ultimately passed to meta_graph.import_scoped_meta_graph. This CL allows setting useful loading options such as import_scope and input_map. Change: 146201221
-rw-r--r--tensorflow/python/saved_model/loader_impl.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index 77c6597665..86f59d6805 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -175,7 +175,7 @@ def maybe_saved_model_directory(export_dir):
return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)
-def load(sess, tags, export_dir):
+def load(sess, tags, export_dir, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
Args:
@@ -185,6 +185,7 @@ def load(sess, tags, export_dir):
SavedModel `save()` API.
export_dir: Directory in which the SavedModel protocol buffer and variables
to be loaded are located.
+ **saver_kwargs: Optional keyword arguments passed through to Saver.
Returns:
The `MetaGraphDef` protocol buffer loaded in the provided session. This
@@ -207,7 +208,7 @@ def load(sess, tags, export_dir):
"[]") + " could not be found in SavedModel")
# Build a saver by importing the meta graph def to load.
- saver = tf_saver.import_meta_graph(meta_graph_def_to_load)
+ saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
# Build the checkpoint path where the variables are located.
variables_path = os.path.join(