aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 17:57:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 18:01:17 -0700
commit9f7a138640408cea58698a432fd1596cf436b484 (patch)
treed3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentb7e9cbab27c893283acc4a6154d7a59dffb23758 (diff)
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 07ec859766..a1be77601c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -351,6 +351,13 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(y, [v])
self.assertAllEqual([None], grad)
+ def testCondOutputShape(self):
+ x = constant_op.constant(1.0)
+ b = control_flow_ops.cond(
+ constant_op.constant(True), lambda: math_ops.square(x),
+ lambda: math_ops.subtract(x, 1.))
+ self.assertEqual(b.shape, tensor_shape.scalar())
+
def testFetchable(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)