diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-02 17:57:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 18:01:17 -0700 |
commit | 9f7a138640408cea58698a432fd1596cf436b484 (patch) | |
tree | d3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | b7e9cbab27c893283acc4a6154d7a59dffb23758 (diff) |
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 07ec859766..a1be77601c 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -351,6 +351,13 @@ class ControlFlowTest(test.TestCase): grad = gradients_impl.gradients(y, [v]) self.assertAllEqual([None], grad) + def testCondOutputShape(self): + x = constant_op.constant(1.0) + b = control_flow_ops.cond( + constant_op.constant(True), lambda: math_ops.square(x), + lambda: math_ops.subtract(x, 1.)) + self.assertEqual(b.shape, tensor_shape.scalar()) + def testFetchable(self): with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) |