aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index f22f3059d1..adc8c51e11 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -350,6 +350,42 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
@test_util.with_c_api
+class SmartCondTest(test_util.TensorFlowTestCase):
+
+ def testSmartCondTrue(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 5))
+ self.assertEqual(z.eval(), 32)
+
+ def testSmartCondFalse(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(4)
+ y = constant_op.constant(3)
+ z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 3))
+ self.assertEqual(z.eval(), 9)
+
+ def testSmartCondMissingArg1(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ control_flow_ops.smart_cond(True, false_fn=lambda: x)
+
+ def testSmartCondMissingArg2(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ control_flow_ops.smart_cond(True, lambda: x)
+
+
+@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):
def testCondTrue(self):