aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-11-22 11:40:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-22 11:43:55 -0800
commit791ef8383d165c116f4c5fc3fda12ebc7eb07edf (patch)
tree004c70a7292067ee45e01f5dbb2c56371482a065 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentcf245240ca90e6b552415f720342ae1acd326590 (diff)
python testing for is_feedable and is_fetchable nodes in the graph
PiperOrigin-RevId: 176682768
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index fc125daf38..1b7f9b110c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -352,14 +352,20 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(y, [v])
self.assertAllEqual([None], grad)
- def testFetchables(self):
+ def testFetchable(self):
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.float32)
control_flow_ops.cond(
constant_op.constant(True), lambda: x + 2, lambda: x + 0)
- tensor_names = all_fetchables()
- for name in tensor_names:
- sess.run(name, feed_dict={x: 3})
+ graph = ops.get_default_graph()
+ for op in graph.get_operations():
+ for t in op.inputs:
+ if graph.is_fetchable(t.op):
+ sess.run(t, feed_dict={x: 3})
+ else:
+ with self.assertRaisesRegexp(ValueError,
+ "has been marked as not fetchable"):
+ sess.run(t, feed_dict={x: 3})
def testFeedable(self):
with self.test_session() as sess: