aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model/loader_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/saved_model/loader_test.py')
-rw-r--r--tensorflow/python/saved_model/loader_test.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index ce18859f6b..9a0b276a4b 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -111,7 +111,8 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.test_session(graph=ops.Graph()) as sess:
- saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
+ saver, _ = loader.load_graph(
+ sess.graph, ["foo_graph"], import_scope="baz")
# The default saver should not work when the import scope is set.
with self.assertRaises(errors.NotFoundError):
@@ -149,7 +150,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_run_init_op(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph()
- saver = loader.load_graph(graph, ["foo_graph"])
+ saver, _ = loader.load_graph(graph, ["foo_graph"])
with self.test_session(graph=graph) as sess:
loader.restore_variables(sess, saver)
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
@@ -203,7 +204,7 @@ class SavedModelLoaderTest(test.TestCase):
loader = loader_impl.SavedModelLoader(path)
with self.test_session(graph=ops.Graph()) as sess:
- saver = loader.load_graph(sess.graph, ["foo_graph"])
+ saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
self.assertFalse(variables._all_saveable_objects())
self.assertIsNotNone(saver)
@@ -212,6 +213,18 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+ def test_load_saved_model_graph_with_return_elements(self):
+ """Ensure that the correct elements are returned."""
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ graph = ops.Graph()
+ _, ret = loader.load_graph(graph, ["foo_graph"],
+ return_elements=["y:0", "x:0"])
+
+ self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0])
+ self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1])
+
+ with self.assertRaisesRegexp(ValueError, "not found in graph"):
+ loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
if __name__ == "__main__":
test.main()