diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-02 13:18:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 13:22:56 -0700 |
commit | 8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch) | |
tree | d651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff) |
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add a output_shapes attr to While op to allow output shapes to be different from the incoming loop_vars.
PiperOrigin-RevId: 215446737
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index ae61be614e..655fece5ff 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1040,7 +1040,6 @@ class ControlFlowTest(test.TestCase): result = r[3].eval() self.assertAllEqual(42, result) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") def testWhile_5(self): with self.cached_session(): @@ -1116,7 +1115,6 @@ class ControlFlowTest(test.TestCase): self._testWhile_Gpu_1(use_gpu=False) self._testWhile_Gpu_1(use_gpu=True) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") def testWhileShape(self): with self.cached_session(): i = constant_op.constant(0) @@ -1152,7 +1150,6 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) self.assertEqual([10000], r.eval()) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") def testWhileShapeInference(self): with self.cached_session(): i = constant_op.constant(0) @@ -1366,6 +1363,7 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0]) self.assertEqual(10, sess.run(r, {b: True})) + @test_util.disable_control_flow_v2("b/116134862 (cond output shape)") def testWhileCondWithControl(self): # Ensure that no control edges by an outer control dependency context are # added to nodes inside cond/while contexts. @@ -1477,6 +1475,7 @@ class ControlFlowTest(test.TestCase): self._testCondWhile_3(use_gpu=False) self._testCondWhile_3(use_gpu=True) + @test_util.disable_control_flow_v2("b/116134862 (cond output shape)") def testWhileCond_1(self): with self.cached_session(): @@ -1493,6 +1492,7 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [i]) self.assertAllEqual(10, r.eval()) + @test_util.disable_control_flow_v2("b/116134862 (cond output shape)") def testWhileCond_2(self): with self.cached_session(): @@ -1502,6 +1502,7 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n]) self.assertAllEqual(10, r.eval()) + @test_util.disable_control_flow_v2("b/116134862 (cond output shape)") def testWhileCond_3(self): with self.cached_session(): @@ -1696,7 +1697,7 @@ class ControlFlowTest(test.TestCase): for i in xrange(10): self.assertEqual([i], q.dequeue().eval()) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") + @test_util.disable_control_flow_v2("b/117119329 (stack)") def testWhileStack_1(self): with self.cached_session(): s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo") @@ -1781,7 +1782,6 @@ class ControlFlowTest(test.TestCase): r = gradients_impl.gradients(r, v)[0] self.assertAllClose(1024.0, r.eval()) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") def testWhileGrad_Shape(self): with self.cached_session(): x = array_ops.placeholder(dtypes.float32, shape=[None]) @@ -2291,7 +2291,6 @@ class ControlFlowTest(test.TestCase): r = sess.run(r, feed_dict={v: 2.0}) self.assertAllClose(1024.0, r) - @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)") def testWhileGrad_Concat(self): with self.cached_session() as sess: x = variable_scope.get_variable("x", initializer=[[1., 2.]]) |