diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-09-11 13:46:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 13:50:50 -0700 |
commit | da99f7ca018d4916447d7b984d9d65be1a9615a8 (patch) | |
tree | 3500e0b7b114e135254187c85732827abda2f20b /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 418c7258687166fc79a04f5a8c903c782a8ad295 (diff) |
Make control_flow_ops._ENABLE_COND_V2 public.
Note this is not part of the official public API, but we do allow
other modules to modify this value (e.g. in tests).
PiperOrigin-RevId: 212512883
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 72 |
1 files changed, 36 insertions, 36 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index eac97af4ed..bdf7e0e4a0 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -333,7 +333,7 @@ class ControlFlowTest(test.TestCase): res.eval(feed_dict={data: 1.0}) def testCondBool(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296297") values = constant_op.constant(10) @@ -384,7 +384,7 @@ class ControlFlowTest(test.TestCase): sess.run(r, feed_dict={t: 3}) def testCondIndexedSlices(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296180") with self.test_session(): @@ -402,7 +402,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(0, ind) def testCondSparseTensor(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113296161 (SparseTensors)") with self.test_session(): @@ -422,7 +422,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r.values.get_shape(), (2,)) def testCondResource(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -438,7 +438,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval()) def testCondIndexedSlicesDifferentTypes(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113293074") with self.test_session(): @@ -484,14 +484,14 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(11, result) def testCond_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") self._testCond_1(use_gpu=False) self._testCond_1(use_gpu=True) def testCond_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -503,7 +503,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(9, result) def testCond_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -518,7 +518,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(12, result) def testCond_4(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): @@ -556,7 +556,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(4, count.eval()) def testCond_6(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -583,7 +583,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([11, 12], sess.run(r)) def testCondRef(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -599,7 +599,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([2.0], r.eval()) def testCondWithControl(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/79881896") with self.test_session() as sess: @@ -641,7 +641,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual([1.0], sess.run(merged_op.output)) def testCondSwitchIdentity(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") # Make sure the recv identity is not removed by optimization. @@ -658,7 +658,7 @@ class ControlFlowTest(test.TestCase): sess.run(r) def testCondRecvIdentity(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") # Make sure the switch identity is not removed by optimization. @@ -677,7 +677,7 @@ class ControlFlowTest(test.TestCase): sess.run(r) def testCondGrad_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113346829 (gpu failure)") graph = ops.Graph() @@ -706,7 +706,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3})) def testCondGrad_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/110550782 (gradient w.r.t external variable)") with self.test_session(): @@ -741,7 +741,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(1.0, result.eval()) def testCondGrad_Gather(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113327884") with self.test_session() as sess: @@ -916,7 +916,7 @@ class ControlFlowTest(test.TestCase): _ = gradients_impl.gradients(loop_with_maxiter, v) def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") v = constant_op.constant(1.0) @@ -1375,7 +1375,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(10, sess.run(r, {b: True})) def testWhileCondWithControl(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") # Ensure that no control edges by an outer control dependency context are @@ -1392,7 +1392,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(0, sess.run(loop)) def testWhileCondWithControl_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): @@ -1417,7 +1417,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(65536.0, v.eval()) def testWhileCondExitControl(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") with self.test_session(): @@ -1443,7 +1443,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(99, v.eval()) def testCondWhile_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -1456,7 +1456,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testCondWhile_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -1469,7 +1469,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def _testCondWhile_3(self, use_gpu): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294340 (enable while_v2)") with self.test_session(use_gpu=use_gpu) as sess: @@ -1498,7 +1498,7 @@ class ControlFlowTest(test.TestCase): self._testCondWhile_3(use_gpu=True) def testWhileCond_1(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1516,7 +1516,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testWhileCond_2(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1527,7 +1527,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(10, r.eval()) def testWhileCond_3(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(): @@ -1872,7 +1872,7 @@ class ControlFlowTest(test.TestCase): self._testWhileGrad_Mul(use_gpu=True, p_iters=10) def _testNestedWhileCondWhileGrad(self, use_gpu): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") with self.test_session(use_gpu=use_gpu): @@ -1913,7 +1913,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(216.0, r[0].eval()) def testWhileGradInCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/110550782 (gradient w.r.t external variable)") with self.test_session(): @@ -1964,7 +1964,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) def testCondGradInNestedWhiles(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113346829 (gpu failure)") def outer_body(i, x): @@ -2280,7 +2280,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(1024.0, r.eval()) def testWhileCondGrad_Simple(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113294377 (unknown shape)") self._testWhileCondGrad_Simple(use_gpu=False) @@ -2633,7 +2633,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(5.0, result.eval()) def testOneValueCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -2651,7 +2651,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual([2], i.eval(feed_dict={c: 0})) def testExampleCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/111124878 (don't return tuple)") with self.test_session(): @@ -2669,7 +2669,7 @@ class ControlFlowTest(test.TestCase): self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) def testCase(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") with self.test_session(): @@ -2724,7 +2724,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(r6.eval(), 0) def testCaseSideEffects(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/112477618 (Operation returned from cond)") with self.test_session() as sess: @@ -2762,7 +2762,7 @@ class ControlFlowTest(test.TestCase): self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1]) def testOneOpCond(self): - if control_flow_ops._ENABLE_COND_V2: + if control_flow_ops.ENABLE_COND_V2: return unittest.skip("b/113324949 (ref vars)") with self.test_session(): |