diff options
Diffstat (limited to 'tensorflow/python/saved_model/loader_test.py')
-rw-r--r-- | tensorflow/python/saved_model/loader_test.py | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py index ce18859f6b..9a0b276a4b 100644 --- a/tensorflow/python/saved_model/loader_test.py +++ b/tensorflow/python/saved_model/loader_test.py @@ -111,7 +111,8 @@ class SavedModelLoaderTest(test.TestCase): def test_load_with_import_scope(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) with self.test_session(graph=ops.Graph()) as sess: - saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz") + saver, _ = loader.load_graph( + sess.graph, ["foo_graph"], import_scope="baz") # The default saver should not work when the import scope is set. with self.assertRaises(errors.NotFoundError): @@ -149,7 +150,7 @@ class SavedModelLoaderTest(test.TestCase): def test_run_init_op(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) graph = ops.Graph() - saver = loader.load_graph(graph, ["foo_graph"]) + saver, _ = loader.load_graph(graph, ["foo_graph"]) with self.test_session(graph=graph) as sess: loader.restore_variables(sess, saver) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) @@ -203,7 +204,7 @@ class SavedModelLoaderTest(test.TestCase): loader = loader_impl.SavedModelLoader(path) with self.test_session(graph=ops.Graph()) as sess: - saver = loader.load_graph(sess.graph, ["foo_graph"]) + saver, _ = loader.load_graph(sess.graph, ["foo_graph"]) self.assertFalse(variables._all_saveable_objects()) self.assertIsNotNone(saver) @@ -212,6 +213,18 @@ class SavedModelLoaderTest(test.TestCase): self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + def test_load_saved_model_graph_with_return_elements(self): + """Ensure that the correct elements are returned.""" + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + graph = ops.Graph() + _, ret = loader.load_graph(graph, ["foo_graph"], + return_elements=["y:0", "x:0"]) + + self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0]) + self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1]) + + with self.assertRaisesRegexp(ValueError, "not found in graph"): + loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"]) if __name__ == "__main__": test.main() |