diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/functional_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/functional_ops_test.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3ddb5e06c9..e39daf1371 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import iterator_ops @@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + def testWhileLowering(self): + + def Run(n, fetch_by_name): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # outputs: [0, n*(n+1)/2] + outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") + + # `outputs` is the list of output tensors of the While op. We + # arbitrarily choose the 0th tensor to get the While op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + if not fetch_by_name: + fetch = outputs[1] + else: + fetch = "my_while:1" + with self.test_session(graph=g, use_gpu=use_gpu) as sess: + return sess.run(fetch) + + self.assertAllEqual(Run(20., False), 210.) + self.assertAllEqual(Run(20., True), 210.) + self.assertAllEqual(Run(100., False), 5050.) + self.assertAllEqual(Run(100., True), 5050.) + def testWhileError(self): for use_gpu in (True, False): with ops.Graph().as_default() as g: |