aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-28 12:46:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 12:53:29 -0700
commit5e66d25666aad9fa76ed8cc0d2b162db76ea0cc8 (patch)
treea5b810c506ee8eb61b707e890ae41b71ac4cb8bd /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parente00954e8626c74b263b90527e0c020cfd64136b2 (diff)
Add flag for enabling while_v2.
Add a single test flag for enabling v2 control flow in tests since we do not plan to support v2 ops with legacy control flow. We have 2 test decorators now: @with_control_flow_v2: Enables all tests in a class to run with v2 control flow. @disable_control_flow_v2: Disables a test function from running in v2. I have removed the skiptests to avoid setup/teardown overheads. Enable tests in control_flow_ops_py_test that run with control_flow_v2. PiperOrigin-RevId: 214980108
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py180
1 files changed, 94 insertions, 86 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 083de84775..d91a848e01 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -63,6 +62,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,7 +125,7 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
@@ -332,10 +332,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -366,6 +364,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +382,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +398,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -435,10 +430,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -510,10 +503,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -587,10 +578,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -629,10 +618,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -646,10 +634,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -665,10 +652,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -694,10 +679,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -729,10 +713,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -756,6 +738,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -764,6 +747,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -779,6 +763,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -794,6 +779,7 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
@@ -824,18 +810,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -861,6 +851,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -904,10 +895,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -939,6 +928,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1048,6 +1039,7 @@ class ControlFlowTest(test.TestCase):
result = r[3].eval()
self.assertAllEqual(42, result)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhile_5(self):
with self.cached_session():
@@ -1072,6 +1064,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1122,6 +1115,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShape(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1139,6 +1133,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1147,6 +1142,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1155,6 +1151,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual([10000], r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShapeInference(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1169,7 +1166,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1180,6 +1177,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1211,6 +1209,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1265,6 +1264,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1297,6 +1297,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1350,6 +1351,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1364,9 +1366,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1380,10 +1379,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1405,9 +1402,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1432,8 +1428,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1445,8 +1439,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1458,9 +1450,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1477,18 +1466,17 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1505,8 +1493,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1516,8 +1502,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1532,6 +1516,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1554,6 +1539,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1580,6 +1566,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1601,7 +1588,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1629,7 +1616,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1659,7 +1646,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1689,6 +1676,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1707,6 +1695,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1775,6 +1764,7 @@ class ControlFlowTest(test.TestCase):
with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1790,6 +1780,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Shape(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -1861,8 +1852,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1885,10 +1874,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1902,8 +1893,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1919,6 +1908,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1936,6 +1926,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1952,9 +1943,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1972,6 +1962,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -1999,6 +1990,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2057,6 +2049,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2067,6 +2060,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2180,10 +2174,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2207,6 +2203,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2230,6 +2227,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2268,13 +2267,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2292,6 +2290,7 @@ class ControlFlowTest(test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Concat(self):
with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
@@ -2315,6 +2314,7 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
@@ -2343,6 +2343,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2364,6 +2365,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2386,6 +2388,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2405,6 +2408,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2509,6 +2514,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2538,6 +2544,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2555,6 +2562,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2572,6 +2580,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2595,6 +2604,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2653,10 +2663,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2708,10 +2717,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2746,10 +2754,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -3031,9 +3037,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3393,7 +3401,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):