aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reduction_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-10 13:08:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-10 14:12:05 -0700
commitc17952f52e61ef701c4ea4b3348a8586f40f3744 (patch)
tree6ff40887e07f9a1d775d95c42cd625be1c98d624 /tensorflow/python/kernel_tests/reduction_ops_test.py
parentf4cd3ec814ffcd1394f9446efd4fcdb90a4d8d35 (diff)
Add more tests for math_ops.reduced_shape.
Change: 121985479
Diffstat (limited to 'tensorflow/python/kernel_tests/reduction_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index cbf4ee1e61..d330040db4 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -27,6 +27,18 @@ from tensorflow.python.ops import math_ops
class ReducedShapeTest(tf.test.TestCase):
+ def testSimple(self):
+ with self.test_session():
+ def check(shape, axes, result):
+ output = math_ops.reduced_shape(shape, axes=axes)
+ self.assertAllEqual(output.eval(), result)
+ check([3], [], [3])
+ check([3], [0], [1])
+ check([5, 3], [], [5, 3])
+ check([5, 3], [0], [1, 3])
+ check([5, 3], [1], [5, 1])
+ check([5, 3], [0, 1], [1, 1])
+
def testZeros(self):
"""Check that reduced_shape does the right thing with zero dimensions."""
with self.test_session():