aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-08 15:45:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-08 15:49:02 -0800
commit28807c5666c9f574ef415fed7b18b99ebed41ecc (patch)
tree15c7249d2925465202e21ac800a81a6df5bda8f2 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentb1c7d177e2aa9a4e3989caf7cfb21a5591c3832f (diff)
Add Operation._remove_all_control_inputs and use in ControlFlowContext.
This allows while loop gradients to work with the C API. This change also enables the C API for control flow tests. PiperOrigin-RevId: 178438424
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 3a61d76f58..e1d3f9a7d4 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -131,6 +132,7 @@ def isum(s, maximum_iterations=None):
return r_s
+@test_util.with_c_api
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
@@ -2648,6 +2650,7 @@ class ControlFlowTest(test.TestCase):
1)
+@test_util.with_c_api
class ControlFlowContextCheckTest(test.TestCase):
def _getWhileTensor(self):
@@ -2764,6 +2767,7 @@ class ControlFlowContextCheckTest(test.TestCase):
lambda: constant_op.constant(0))
+@test_util.with_c_api
class TupleTest(test.TestCase):
def testTensors(self):
@@ -2849,6 +2853,7 @@ class TupleTest(test.TestCase):
self.assertEquals(1, var.eval())
+@test_util.with_c_api
class AssertTest(test.TestCase):
def testGuardedAssertDoesNotCopyWhenTrue(self):
@@ -2886,6 +2891,7 @@ class AssertTest(test.TestCase):
self.assertEqual([], guarded_memcpy_nodestat_names)
+@test_util.with_c_api
class WhileOpBenchmark(test.Benchmark):
"""Evaluate the performance of while_loop op."""
@@ -2999,6 +3005,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
+@test_util.with_c_api
class EagerTest(test.TestCase):
def testCond(self):