diff options
author | 2016-11-02 09:45:30 -0800 | |
---|---|---|
committer | 2016-11-02 11:05:58 -0700 | |
commit | 350fd444ce1592c91c150c180dba6a4f2c9de136 (patch) | |
tree | 53f58792bd630ee9030cf10abb8d083e4ad496ed | |
parent | 87ef75cfb4f4d22b8e2e59b42c30d825f302d0b7 (diff) |
Initialize resources in the prediction path before loading from a checkpoint.
Reload all saveable objects from the graph.
Change: 137964298
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions_test.py | 15 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 |
4 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 62d7bb77c9..b93089c9cb 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -291,7 +291,9 @@ py_test( deps = [ ":learn", "//tensorflow:tensorflow_py", + "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:test_ops", ], ) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 0c5152b553..baee707a5f 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks @@ -77,7 +78,8 @@ def get_summary_writer(logdir): def _make_saver(graph, keep_checkpoint_max=5): - vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES) + vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) + + graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)) if vars_to_save: return tf_saver.Saver(vars_to_save, sharded=True, @@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None): raise ValueError('feed_dicts is invalid: %s.' % feed_dicts) graph = contrib_ops.get_graph_from_inputs(output_dict.values()) - with graph.as_default() as g: with tf_session.Session('') as session: + session.run( + resources.initialize_resources(resources.shared_resources() + + resources.local_resources())) if restore_checkpoint_path: _restore_from_checkpoint(session, g, restore_checkpoint_path) else: diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 9a7306ad4a..c8c73d5de5 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import test_ops +from tensorflow.python.ops import resources from tensorflow.python.ops import variables @@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase): pass self.assertTrue(request_stop.called) + def test_run_feeds_iter_calls_resources_init(self): + with tf.Graph().as_default() as g: + in0, _, _ = self._build_inference_graph() + handle = test_ops.stub_resource_handle_op(container='a', shared_name='b') + resources.register_resource( + handle=handle, + create_op=test_ops.resource_create_op(handle), + is_initialized_op=test_ops.resource_initialized_op(handle)) + + for _ in learn.graph_actions.run_feeds_iter({'in0': in0}, + feed_dicts=[{}]): + self.assertTrue(test_ops.resource_initialized_op(handle).eval()) + def test_infer_different_default_graph(self): with self.test_session(): self._assert_ckpt(self._output_dir, False) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a32c76273d..1701aaf5cb 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -41,6 +41,7 @@ py_library( ":summary", ":training", ":ops", + ":test_ops", "//tensorflow/python/debug:debug_py", ] + if_not_windows([ "//tensorflow/contrib:contrib_py", |