aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-10-01 13:40:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:44:35 -0700
commitc86f5941359526b91d85daf844e94ff5d39b2d6c (patch)
treeaf0e32582187d30a58e1da7c6e18f01ccb701c36 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent1630584951975479dee852cf6f7603fe6819fde1 (diff)
Make cond_v2 If op lowering work in a defun + eager.
Prior to this change, the lowering pass assumed that the If op functions would be available in the If op's graph. If the If op is defined in a defun and then called via eager execution, the functions will be in the eager context, but not in the defun's graph. This change makes the lowering pass correctly use the function library passed in by the caller via GraphOptimizationPassOptions. PiperOrigin-RevId: 215271990
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index d91a848e01..ae61be614e 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -3414,6 +3415,27 @@ class EagerTest(test.TestCase):
self.assertAllEqual(r.numpy(), 10)
self.assertFalse(isinstance(r, list))
+ def testCondInDefun(self):
+ if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
+ return unittest.skip("b/113346829 (gpu failure)")
+
+ with context.eager_mode():
+
+ @eager_function.defun
+ def foo(pred):
+ # TODO(b/111124878): this only needs to output one element.
+ fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
+ fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
+ return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
+
+ r = foo(True)
+ self.assertAllEqual(r[0].numpy(), 10)
+ self.assertNotIsInstance(r, list)
+
+ r = foo(False)
+ self.assertAllEqual(r[0].numpy(), 20)
+ self.assertFalse(isinstance(r, list))
+
def testWhileLoop(self):
with context.eager_mode():
tensor = constant_op.constant([1, 2, 3, 4, 5])