aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/cond_v2_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/cond_v2_test.py')
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py372
1 files changed, 367 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 759db5d5f4..97ce245fc8 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
@@ -35,10 +36,12 @@ from tensorflow.python.training import saver
from tensorflow.python.util import compat
-class NewCondTest(test.TestCase):
+class CondV2Test(test.TestCase):
- def _testCond(self, true_fn, false_fn, train_vals):
- with self.test_session() as sess:
+ def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
+ if not feed_dict:
+ feed_dict = {}
+ with self.test_session(graph=ops.get_default_graph()) as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
@@ -47,13 +50,17 @@ class NewCondTest(test.TestCase):
expected_grad = gradients_impl.gradients(expected, train_vals)
actual_grad = gradients_impl.gradients(actual, train_vals)
+ sess_run_args = {pred: True}
+ sess_run_args.update(feed_dict)
expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
- (expected, actual, expected_grad, actual_grad), {pred: True})
+ (expected, actual, expected_grad, actual_grad), sess_run_args)
self.assertEqual(expected_val, actual_val)
self.assertEqual(expected_grad_val, actual_grad_val)
+ sess_run_args = {pred: False}
+ sess_run_args.update(feed_dict)
expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
- (expected, actual, expected_grad, actual_grad), {pred: False})
+ (expected, actual, expected_grad, actual_grad), sess_run_args)
self.assertEqual(expected_val, actual_val)
self.assertEqual(expected_grad_val, actual_grad_val)
@@ -131,6 +138,349 @@ class NewCondTest(test.TestCase):
self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions)
self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
+ def testDefunInCond(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+
+ @function.Defun()
+ def fn():
+ return x * y * 2.0
+
+ return fn()
+
+ def false_fn():
+ return 2.0
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
+ def testNestedDefunInCond(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ @function.Defun()
+ def fn():
+
+ @function.Defun()
+ def nested_fn():
+ return x * y * 2.0
+
+ return nested_fn()
+
+ return fn()
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
+ def testDoubleNestedDefunInCond(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+
+ @function.Defun()
+ def fn():
+
+ @function.Defun()
+ def nested_fn():
+
+ @function.Defun()
+ def nested_nested_fn():
+ return x * y * 2.0
+
+ return nested_nested_fn()
+
+ return nested_fn()
+
+ return fn()
+
+ def false_fn():
+ return 2.0
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
+ def testNestedCond(self):
+
+ def run_test(pred_value):
+
+ def build_graph():
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ def false_true_fn():
+ return x * y * 2.0
+
+ def false_false_fn():
+ return x * 5.0
+
+ return _cond(pred, false_true_fn, false_false_fn, "inside_false_fn")
+
+ return x, y, pred, true_fn, false_fn
+
+ with ops.Graph().as_default():
+ x, y, pred, true_fn, false_fn = build_graph()
+ self._testCond(true_fn, false_fn, [x, y], {pred: pred_value})
+ self._testCond(true_fn, false_fn, [x], {pred: pred_value})
+ self._testCond(true_fn, false_fn, [y], {pred: pred_value})
+
+ run_test(True)
+ run_test(False)
+
+ def testDoubleNestedCond(self):
+
+ def run_test(pred1_value, pred2_value):
+
+ def build_graph():
+ pred1 = array_ops.placeholder(dtypes.bool, name="pred1")
+ pred2 = array_ops.placeholder(dtypes.bool, name="pred2")
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ def false_true_fn():
+
+ def false_true_true_fn():
+ return x * y * 2.0
+
+ def false_true_false_fn():
+ return x * 10.0
+
+ return _cond(
+ pred1,
+ false_true_true_fn,
+ false_true_false_fn,
+ name="inside_false_true_fn")
+
+ def false_false_fn():
+ return x * 5.0
+
+ return _cond(
+ pred2, false_true_fn, false_false_fn, name="inside_false_fn")
+
+ return x, y, pred1, pred2, true_fn, false_fn
+
+ with ops.Graph().as_default():
+ x, y, pred1, pred2, true_fn, false_fn = build_graph()
+ self._testCond(true_fn, false_fn, [x, y], {
+ pred1: pred1_value,
+ pred2: pred2_value
+ })
+ x, y, pred1, pred2, true_fn, false_fn = build_graph()
+ self._testCond(true_fn, false_fn, [x], {
+ pred1: pred1_value,
+ pred2: pred2_value
+ })
+ x, y, pred1, pred2, true_fn, false_fn = build_graph()
+ self._testCond(true_fn, false_fn, [y], {
+ pred1: pred1_value,
+ pred2: pred2_value
+ })
+
+ run_test(True, True)
+ run_test(True, False)
+ run_test(False, False)
+ run_test(False, True)
+
+ def testGradientFromInsideDefun(self):
+
+ def build_graph():
+ pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
+ pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ def inner_true_fn():
+ return x * y * 2.0
+
+ def inner_false_fn():
+ return x * 5.0
+
+ return cond_v2.cond_v2(
+ pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
+
+ cond_outer = cond_v2.cond_v2(
+ pred_outer, true_fn, false_fn, name="outer_cond")
+
+ # Compute grads inside a Defun.
+ @function.Defun()
+ def nesting_fn():
+ return gradients_impl.gradients(cond_outer, [x, y])
+
+ grads = nesting_fn()
+
+ return grads, pred_outer, pred_inner
+
+ with ops.Graph().as_default():
+ grads, pred_outer, pred_inner = build_graph()
+ with self.test_session(graph=ops.get_default_graph()) as sess:
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
+
+ def testGradientFromInsideNestedDefun(self):
+
+ def build_graph():
+ pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
+ pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ def inner_true_fn():
+ return x * y * 2.0
+
+ def inner_false_fn():
+ return x * 5.0
+
+ return cond_v2.cond_v2(
+ pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
+
+ cond_outer = cond_v2.cond_v2(
+ pred_outer, true_fn, false_fn, name="outer_cond")
+
+ # Compute grads inside a Defun.
+ @function.Defun()
+ def nesting_fn():
+
+ @function.Defun()
+ def inner_nesting_fn():
+ return gradients_impl.gradients(cond_outer, [x, y])
+
+ return inner_nesting_fn()
+
+ grads = nesting_fn()
+
+ return grads, pred_outer, pred_inner
+
+ with ops.Graph().as_default():
+ grads, pred_outer, pred_inner = build_graph()
+ with self.test_session(graph=ops.get_default_graph()) as sess:
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
+
+ def testBuildCondAndGradientInsideDefun(self):
+
+ def build_graph():
+ pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
+ pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ # Build cond and its gradient inside a Defun.
+ @function.Defun()
+ def fn():
+
+ def true_fn():
+ return 2.0
+
+ def false_fn():
+
+ def inner_true_fn():
+ return x * y * 2.0
+
+ def inner_false_fn():
+ return x * 5.0
+
+ return cond_v2.cond_v2(
+ pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
+
+ cond_outer = cond_v2.cond_v2(
+ pred_outer, true_fn, false_fn, name="outer_cond")
+ return gradients_impl.gradients(cond_outer, [x, y])
+
+ grads = fn()
+
+ return grads, pred_outer, pred_inner
+
+ with ops.Graph().as_default():
+ grads, pred_outer, pred_inner = build_graph()
+ with self.test_session(graph=ops.get_default_graph()) as sess:
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
+
def testSecondDerivative(self):
with self.test_session() as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
@@ -532,5 +882,17 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+def _cond(pred, true_fn, false_fn, name):
+ if _is_old_cond():
+ return control_flow_ops.cond(pred, true_fn, false_fn, name=name)
+ else:
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
+
+
+def _is_old_cond():
+ return isinstance(ops.get_default_graph()._get_control_flow_context(),
+ control_flow_ops.CondContext)
+
+
if __name__ == "__main__":
test.main()