diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-10-01 13:40:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 13:44:35 -0700 |
commit | c86f5941359526b91d85daf844e94ff5d39b2d6c (patch) | |
tree | af0e32582187d30a58e1da7c6e18f01ccb701c36 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 1630584951975479dee852cf6f7603fe6819fde1 (diff) |
Make cond_v2 If op lowering work in a defun + eager.
Prior to this change, the lowering pass assumed that the If op
functions would be available in the If op's graph. If the If op is
defined in a defun and then called via eager execution, the functions
will be in the eager context, but not in the defun's graph. This
change makes the lowering pass correctly use the function library
passed in by the caller via GraphOptimizationPassOptions.
PiperOrigin-RevId: 215271990
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index d91a848e01..ae61be614e 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -3414,6 +3415,27 @@ class EagerTest(test.TestCase): self.assertAllEqual(r.numpy(), 10) self.assertFalse(isinstance(r, list)) + def testCondInDefun(self): + if "GPU" in [d.device_type for d in device_lib.list_local_devices()]: + return unittest.skip("b/113346829 (gpu failure)") + + with context.eager_mode(): + + @eager_function.defun + def foo(pred): + # TODO(b/111124878): this only needs to output one element. + fn1 = lambda: (constant_op.constant(10), constant_op.constant(100)) + fn2 = lambda: (constant_op.constant(20), constant_op.constant(200)) + return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2) + + r = foo(True) + self.assertAllEqual(r[0].numpy(), 10) + self.assertNotIsInstance(r, list) + + r = foo(False) + self.assertAllEqual(r[0].numpy(), 20) + self.assertFalse(isinstance(r, list)) + def testWhileLoop(self): with context.eager_mode(): tensor = constant_op.constant([1, 2, 3, 4, 5]) |