aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-07-20 15:45:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:51:41 -0700
commit6c528feaf820bdde820833ad24e05167adb5daa7 (patch)
treecdffac07b9e343e03958b734ac9553102bbd4ccf /tensorflow/python/saved_model
parent5e876a8c25819070d78aa96595943afa207a6671 (diff)
Automated rollback of commit 8257891f378027a1a7c0403ba6ba0aeb313496a0
PiperOrigin-RevId: 205466000
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/loader_impl.py13
-rw-r--r--tensorflow/python/saved_model/loader_test.py19
2 files changed, 8 insertions, 24 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index 685a913f9c..e5f649fdab 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -284,15 +284,12 @@ class SavedModelLoader(object):
**saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
Returns:
- A tuple of
- * Saver defined by the MetaGraph, which can be used to restore the
- variable values.
- * List of `Operation`/`Tensor` objects returned from
- `tf.import_graph_def` (may be `None`).
+ Saver defined by the MetaGraph, which can be used to restore the variable
+ values.
"""
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
with graph.as_default():
- return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access
+ return tf_saver.import_meta_graph(
meta_graph_def, import_scope=import_scope, **saver_kwargs)
def restore_variables(self, sess, saver, import_scope=None):
@@ -364,8 +361,8 @@ class SavedModelLoader(object):
`MetagraphDef` proto of the graph that was loaded.
"""
with sess.graph.as_default():
- saver, _ = self.load_graph(sess.graph, tags, import_scope,
- **saver_kwargs)
+ saver = self.load_graph(sess.graph, tags, import_scope,
+ **saver_kwargs)
self.restore_variables(sess, saver, import_scope)
self.run_init_ops(sess, tags, import_scope)
return self.get_meta_graph_def_from_tags(tags)
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index 9a0b276a4b..ce18859f6b 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -111,8 +111,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.test_session(graph=ops.Graph()) as sess:
- saver, _ = loader.load_graph(
- sess.graph, ["foo_graph"], import_scope="baz")
+ saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
# The default saver should not work when the import scope is set.
with self.assertRaises(errors.NotFoundError):
@@ -150,7 +149,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_run_init_op(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph()
- saver, _ = loader.load_graph(graph, ["foo_graph"])
+ saver = loader.load_graph(graph, ["foo_graph"])
with self.test_session(graph=graph) as sess:
loader.restore_variables(sess, saver)
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
@@ -204,7 +203,7 @@ class SavedModelLoaderTest(test.TestCase):
loader = loader_impl.SavedModelLoader(path)
with self.test_session(graph=ops.Graph()) as sess:
- saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
+ saver = loader.load_graph(sess.graph, ["foo_graph"])
self.assertFalse(variables._all_saveable_objects())
self.assertIsNotNone(saver)
@@ -213,18 +212,6 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
- def test_load_saved_model_graph_with_return_elements(self):
- """Ensure that the correct elements are returned."""
- loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- graph = ops.Graph()
- _, ret = loader.load_graph(graph, ["foo_graph"],
- return_elements=["y:0", "x:0"])
-
- self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0])
- self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1])
-
- with self.assertRaisesRegexp(ValueError, "not found in graph"):
- loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
if __name__ == "__main__":
test.main()