diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-06 10:20:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 10:24:30 -0700 |
commit | 84f091dff8e1bcd93ac2d69d2cc11faca3790ac9 (patch) | |
tree | 9b2a165cdb9b4c514220ac52e98be37aac92e509 | |
parent | 43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 (diff) |
Add python test for While op lowering.
Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering.
PiperOrigin-RevId: 211827934
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/functional_ops_test.py | 35 |
2 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3026c7755a..58c8975daa 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1634,6 +1634,7 @@ cuda_py_test( srcs = ["functional_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3ddb5e06c9..e39daf1371 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import iterator_ops @@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + def testWhileLowering(self): + + def Run(n, fetch_by_name): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # outputs: [0, n*(n+1)/2] + outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") + + # `outputs` is the list of output tensors of the While op. We + # arbitrarily choose the 0th tensor to get the While op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + if not fetch_by_name: + fetch = outputs[1] + else: + fetch = "my_while:1" + with self.test_session(graph=g, use_gpu=use_gpu) as sess: + return sess.run(fetch) + + self.assertAllEqual(Run(20., False), 210.) + self.assertAllEqual(Run(20., True), 210.) + self.assertAllEqual(Run(100., False), 5050.) + self.assertAllEqual(Run(100., True), 5050.) + def testWhileError(self): for use_gpu in (True, False): with ops.Graph().as_default() as g: |