aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-02 09:45:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-02 11:05:58 -0700
commit350fd444ce1592c91c150c180dba6a4f2c9de136 (patch)
tree53f58792bd630ee9030cf10abb8d083e4ad496ed
parent87ef75cfb4f4d22b8e2e59b42c30d825f302d0b7 (diff)
Initialize resources in the prediction path before loading from a checkpoint.
Reload all saveable objects from the graph. Change: 137964298
-rw-r--r--tensorflow/contrib/learn/BUILD2
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py15
-rw-r--r--tensorflow/python/BUILD1
4 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 62d7bb77c9..b93089c9cb 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -291,7 +291,9 @@ py_test(
deps = [
":learn",
"//tensorflow:tensorflow_py",
+ "//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:test_ops",
],
)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 0c5152b553..baee707a5f 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -40,6 +40,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
@@ -77,7 +78,8 @@ def get_summary_writer(logdir):
def _make_saver(graph, keep_checkpoint_max=5):
- vars_to_save = graph.get_collection(ops.GraphKeys.VARIABLES)
+ vars_to_save = (graph.get_collection(ops.GraphKeys.VARIABLES) +
+ graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
if vars_to_save:
return tf_saver.Saver(vars_to_save,
sharded=True,
@@ -846,9 +848,11 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)
graph = contrib_ops.get_graph_from_inputs(output_dict.values())
-
with graph.as_default() as g:
with tf_session.Session('') as session:
+ session.run(
+ resources.initialize_resources(resources.shared_resources() +
+ resources.local_resources()))
if restore_checkpoint_path:
_restore_from_checkpoint(session, g, restore_checkpoint_path)
else:
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 9a7306ad4a..c8c73d5de5 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -28,6 +28,8 @@ from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
@@ -194,6 +196,19 @@ class GraphActionsTest(tf.test.TestCase):
pass
self.assertTrue(request_stop.called)
+ def test_run_feeds_iter_calls_resources_init(self):
+ with tf.Graph().as_default() as g:
+ in0, _, _ = self._build_inference_graph()
+ handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
+ resources.register_resource(
+ handle=handle,
+ create_op=test_ops.resource_create_op(handle),
+ is_initialized_op=test_ops.resource_initialized_op(handle))
+
+ for _ in learn.graph_actions.run_feeds_iter({'in0': in0},
+ feed_dicts=[{}]):
+ self.assertTrue(test_ops.resource_initialized_op(handle).eval())
+
def test_infer_different_default_graph(self):
with self.test_session():
self._assert_ckpt(self._output_dir, False)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a32c76273d..1701aaf5cb 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -41,6 +41,7 @@ py_library(
":summary",
":training",
":ops",
+ ":test_ops",
"//tensorflow/python/debug:debug_py",
] + if_not_windows([
"//tensorflow/contrib:contrib_py",