aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-09-11 13:46:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 13:50:50 -0700
commitda99f7ca018d4916447d7b984d9d65be1a9615a8 (patch)
tree3500e0b7b114e135254187c85732827abda2f20b /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent418c7258687166fc79a04f5a8c903c782a8ad295 (diff)
Make control_flow_ops._ENABLE_COND_V2 public.
Note this is not part of the official public API, but we do allow other modules to modify this value (e.g. in tests). PiperOrigin-RevId: 212512883
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py72
1 files changed, 36 insertions, 36 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index eac97af4ed..bdf7e0e4a0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -333,7 +333,7 @@ class ControlFlowTest(test.TestCase):
res.eval(feed_dict={data: 1.0})
def testCondBool(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296297")
values = constant_op.constant(10)
@@ -384,7 +384,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r, feed_dict={t: 3})
def testCondIndexedSlices(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296180")
with self.test_session():
@@ -402,7 +402,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(0, ind)
def testCondSparseTensor(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296161 (SparseTensors)")
with self.test_session():
@@ -422,7 +422,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -438,7 +438,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113293074")
with self.test_session():
@@ -484,14 +484,14 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -503,7 +503,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -518,7 +518,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, result)
def testCond_4(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
with self.test_session():
@@ -556,7 +556,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -583,7 +583,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -599,7 +599,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([2.0], r.eval())
def testCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/79881896")
with self.test_session() as sess:
@@ -641,7 +641,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([1.0], sess.run(merged_op.output))
def testCondSwitchIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the recv identity is not removed by optimization.
@@ -658,7 +658,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondRecvIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the switch identity is not removed by optimization.
@@ -677,7 +677,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
graph = ops.Graph()
@@ -706,7 +706,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
def testCondGrad_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.test_session():
@@ -741,7 +741,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, result.eval())
def testCondGrad_Gather(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113327884")
with self.test_session() as sess:
@@ -916,7 +916,7 @@ class ControlFlowTest(test.TestCase):
_ = gradients_impl.gradients(loop_with_maxiter, v)
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
v = constant_op.constant(1.0)
@@ -1375,7 +1375,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
# Ensure that no control edges by an outer control dependency context are
@@ -1392,7 +1392,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, sess.run(loop))
def testWhileCondWithControl_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
with self.test_session():
@@ -1417,7 +1417,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(65536.0, v.eval())
def testWhileCondExitControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session():
@@ -1443,7 +1443,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -1456,7 +1456,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -1469,7 +1469,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session(use_gpu=use_gpu) as sess:
@@ -1498,7 +1498,7 @@ class ControlFlowTest(test.TestCase):
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
@@ -1516,7 +1516,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
@@ -1527,7 +1527,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
@@ -1872,7 +1872,7 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
@@ -1913,7 +1913,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.test_session():
@@ -1964,7 +1964,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testCondGradInNestedWhiles(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
@@ -2280,7 +2280,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileCondGrad_Simple(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
self._testWhileCondGrad_Simple(use_gpu=False)
@@ -2633,7 +2633,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -2651,7 +2651,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
@@ -2669,7 +2669,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testCase(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
with self.test_session():
@@ -2724,7 +2724,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
def testCaseSideEffects(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
with self.test_session() as sess:
@@ -2762,7 +2762,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
def testOneOpCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
with self.test_session():