aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-06 10:20:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 10:24:30 -0700
commit84f091dff8e1bcd93ac2d69d2cc11faca3790ac9 (patch)
tree9b2a165cdb9b4c514220ac52e98be37aac92e509
parent43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 (diff)
Add python test for While op lowering.
Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering. PiperOrigin-RevId: 211827934
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py35
2 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3026c7755a..58c8975daa 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1634,6 +1634,7 @@ cuda_py_test(
srcs = ["functional_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 3ddb5e06c9..e39daf1371 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
@@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
+ def testWhileLowering(self):
+
+ def Run(n, fetch_by_name):
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # outputs: [0, n*(n+1)/2]
+ outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
+
+ # `outputs` is the list of output tensors of the While op. We
+ # arbitrarily choose the 0th tensor to get the While op and set the
+ # lowering attribute on it.
+ outputs[0].op._set_attr("_lower_using_switch_merge",
+ attr_value_pb2.AttrValue(b=True))
+ if not fetch_by_name:
+ fetch = outputs[1]
+ else:
+ fetch = "my_while:1"
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ return sess.run(fetch)
+
+ self.assertAllEqual(Run(20., False), 210.)
+ self.assertAllEqual(Run(20., True), 210.)
+ self.assertAllEqual(Run(100., False), 5050.)
+ self.assertAllEqual(Run(100., True), 5050.)
+
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g: