diff options
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 43fe045bcb..153548ae92 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): # Expect a tuple since that is what the body returns. self.assertEqual(self.evaluate(r), (10,)) + def testWhileLoopSameReturnShape_False(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the tensor. + r = control_flow_ops.while_loop(c, b, [i, []]) + self.assertEqual(self.evaluate(r), 10) + + def testWhileLoopSameReturnShape_True(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the original structure. + r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True) + self.assertEqual(self.evaluate(r), [10, []]) + if __name__ == "__main__": googletest.main() |