diff options
author | 2017-01-31 20:51:01 -0800 | |
---|---|---|
committer | 2017-01-31 21:12:58 -0800 | |
commit | 57f609d07fc9d11fef9cf415d8ebe902bee5d89f (patch) | |
tree | 725ad377e4f4554b4ca9997859b659bad6966a29 | |
parent | 744e4070c503dccd07194e389c920b1f8167ca0d (diff) |
Allow passing arguments through saved_model loader to saver.import_meta_graph.
These **kwargs are ultimately passed to meta_graph.import_scoped_meta_graph. This CL allows setting useful loading options such as import_scope and input_map.
Change: 146201221
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 77c6597665..86f59d6805 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -175,7 +175,7 @@ def maybe_saved_model_directory(export_dir): return file_io.file_exists(txt_path) or file_io.file_exists(pb_path) -def load(sess, tags, export_dir): +def load(sess, tags, export_dir, **saver_kwargs): """Loads the model from a SavedModel as specified by tags. Args: @@ -185,6 +185,7 @@ def load(sess, tags, export_dir): SavedModel `save()` API. export_dir: Directory in which the SavedModel protocol buffer and variables to be loaded are located. + **saver_kwargs: Optional keyword arguments passed through to Saver. Returns: The `MetaGraphDef` protocol buffer loaded in the provided session. This @@ -207,7 +208,7 @@ def load(sess, tags, export_dir): "[]") + " could not be found in SavedModel") # Build a saver by importing the meta graph def to load. - saver = tf_saver.import_meta_graph(meta_graph_def_to_load) + saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) # Build the checkpoint path where the variables are located. variables_path = os.path.join( |