diff options
author | Katherine Wu <kathywu@google.com> | 2018-07-20 15:45:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 15:51:41 -0700 |
commit | 6c528feaf820bdde820833ad24e05167adb5daa7 (patch) | |
tree | cdffac07b9e343e03958b734ac9553102bbd4ccf /tensorflow/python/saved_model | |
parent | 5e876a8c25819070d78aa96595943afa207a6671 (diff) |
Automated rollback of commit 8257891f378027a1a7c0403ba6ba0aeb313496a0
PiperOrigin-RevId: 205466000
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 13 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader_test.py | 19 |
2 files changed, 8 insertions, 24 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 685a913f9c..e5f649fdab 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -284,15 +284,12 @@ class SavedModelLoader(object): **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. Returns: - A tuple of - * Saver defined by the MetaGraph, which can be used to restore the - variable values. - * List of `Operation`/`Tensor` objects returned from - `tf.import_graph_def` (may be `None`). + Saver defined by the MetaGraph, which can be used to restore the variable + values. """ meta_graph_def = self.get_meta_graph_def_from_tags(tags) with graph.as_default(): - return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access + return tf_saver.import_meta_graph( meta_graph_def, import_scope=import_scope, **saver_kwargs) def restore_variables(self, sess, saver, import_scope=None): @@ -364,8 +361,8 @@ class SavedModelLoader(object): `MetagraphDef` proto of the graph that was loaded. """ with sess.graph.as_default(): - saver, _ = self.load_graph(sess.graph, tags, import_scope, - **saver_kwargs) + saver = self.load_graph(sess.graph, tags, import_scope, + **saver_kwargs) self.restore_variables(sess, saver, import_scope) self.run_init_ops(sess, tags, import_scope) return self.get_meta_graph_def_from_tags(tags) diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py index 9a0b276a4b..ce18859f6b 100644 --- a/tensorflow/python/saved_model/loader_test.py +++ b/tensorflow/python/saved_model/loader_test.py @@ -111,8 +111,7 @@ class SavedModelLoaderTest(test.TestCase): def test_load_with_import_scope(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) with self.test_session(graph=ops.Graph()) as sess: - saver, _ = loader.load_graph( - sess.graph, ["foo_graph"], import_scope="baz") + saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz") # The default saver should not work when the import scope is set. with self.assertRaises(errors.NotFoundError): @@ -150,7 +149,7 @@ class SavedModelLoaderTest(test.TestCase): def test_run_init_op(self): loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) graph = ops.Graph() - saver, _ = loader.load_graph(graph, ["foo_graph"]) + saver = loader.load_graph(graph, ["foo_graph"]) with self.test_session(graph=graph) as sess: loader.restore_variables(sess, saver) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) @@ -204,7 +203,7 @@ class SavedModelLoaderTest(test.TestCase): loader = loader_impl.SavedModelLoader(path) with self.test_session(graph=ops.Graph()) as sess: - saver, _ = loader.load_graph(sess.graph, ["foo_graph"]) + saver = loader.load_graph(sess.graph, ["foo_graph"]) self.assertFalse(variables._all_saveable_objects()) self.assertIsNotNone(saver) @@ -213,18 +212,6 @@ class SavedModelLoaderTest(test.TestCase): self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) - def test_load_saved_model_graph_with_return_elements(self): - """Ensure that the correct elements are returned.""" - loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) - graph = ops.Graph() - _, ret = loader.load_graph(graph, ["foo_graph"], - return_elements=["y:0", "x:0"]) - - self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0]) - self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1]) - - with self.assertRaisesRegexp(ValueError, "not found in graph"): - loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"]) if __name__ == "__main__": test.main() |