aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/functional_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/functional_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 3ddb5e06c9..e39daf1371 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
@@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
+ def testWhileLowering(self):
+
+ def Run(n, fetch_by_name):
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # outputs: [0, n*(n+1)/2]
+ outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
+
+ # `outputs` is the list of output tensors of the While op. We
+ # arbitrarily choose the 0th tensor to get the While op and set the
+ # lowering attribute on it.
+ outputs[0].op._set_attr("_lower_using_switch_merge",
+ attr_value_pb2.AttrValue(b=True))
+ if not fetch_by_name:
+ fetch = outputs[1]
+ else:
+ fetch = "my_while:1"
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ return sess.run(fetch)
+
+ self.assertAllEqual(Run(20., False), 210.)
+ self.assertAllEqual(Run(20., True), 210.)
+ self.assertAllEqual(Run(100., False), 5050.)
+ self.assertAllEqual(Run(100., True), 5050.)
+
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g: