aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 13:18:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:22:56 -0700
commit8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch)
treed651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff)
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add a output_shapes attr to While op to allow output shapes to be different from the incoming loop_vars. PiperOrigin-RevId: 215446737
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index ae61be614e..655fece5ff 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1040,7 +1040,6 @@ class ControlFlowTest(test.TestCase):
result = r[3].eval()
self.assertAllEqual(42, result)
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhile_5(self):
with self.cached_session():
@@ -1116,7 +1115,6 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShape(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1152,7 +1150,6 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual([10000], r.eval())
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShapeInference(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1366,6 +1363,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCondWithControl(self):
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
@@ -1477,6 +1475,7 @@ class ControlFlowTest(test.TestCase):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_1(self):
with self.cached_session():
@@ -1493,6 +1492,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [i])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_2(self):
with self.cached_session():
@@ -1502,6 +1502,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_3(self):
with self.cached_session():
@@ -1696,7 +1697,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
+ @test_util.disable_control_flow_v2("b/117119329 (stack)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1781,7 +1782,6 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Shape(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -2291,7 +2291,6 @@ class ControlFlowTest(test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
- @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Concat(self):
with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])