diff options
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index f22f3059d1..adc8c51e11 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -350,6 +350,42 @@ class SwitchTestCase(test_util.TensorFlowTestCase): @test_util.with_c_api +class SmartCondTest(test_util.TensorFlowTestCase): + + def testSmartCondTrue(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 5)) + self.assertEqual(z.eval(), 32) + + def testSmartCondFalse(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(4) + y = constant_op.constant(3) + z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 3)) + self.assertEqual(z.eval(), 9) + + def testSmartCondMissingArg1(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.smart_cond(True, false_fn=lambda: x) + + def testSmartCondMissingArg2(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.smart_cond(True, lambda: x) + + +@test_util.with_c_api class CondTest(test_util.TensorFlowTestCase): def testCondTrue(self): |