diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-12-08 15:45:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-08 15:49:02 -0800 |
commit | 28807c5666c9f574ef415fed7b18b99ebed41ecc (patch) | |
tree | 15c7249d2925465202e21ac800a81a6df5bda8f2 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | b1c7d177e2aa9a4e3989caf7cfb21a5591c3832f (diff) |
Add Operation._remove_all_control_inputs and use in ControlFlowContext.
This allows while loop gradients to work with the C API. This change
also enables the C API for control flow tests.
PiperOrigin-RevId: 178438424
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 3a61d76f58..e1d3f9a7d4 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -131,6 +132,7 @@ def isum(s, maximum_iterations=None): return r_s +@test_util.with_c_api class ControlFlowTest(test.TestCase): def testRefIdentity(self): @@ -2648,6 +2650,7 @@ class ControlFlowTest(test.TestCase): 1) +@test_util.with_c_api class ControlFlowContextCheckTest(test.TestCase): def _getWhileTensor(self): @@ -2764,6 +2767,7 @@ class ControlFlowContextCheckTest(test.TestCase): lambda: constant_op.constant(0)) +@test_util.with_c_api class TupleTest(test.TestCase): def testTensors(self): @@ -2849,6 +2853,7 @@ class TupleTest(test.TestCase): self.assertEquals(1, var.eval()) +@test_util.with_c_api class AssertTest(test.TestCase): def testGuardedAssertDoesNotCopyWhenTrue(self): @@ -2886,6 +2891,7 @@ class AssertTest(test.TestCase): self.assertEqual([], guarded_memcpy_nodestat_names) +@test_util.with_c_api class WhileOpBenchmark(test.Benchmark): """Evaluate the performance of while_loop op.""" @@ -2999,6 +3005,7 @@ class WhileOpBenchmark(test.Benchmark): name="unroll_same_device", iters=iters, wall_time=duration) +@test_util.with_c_api class EagerTest(test.TestCase): def testCond(self): |