aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 12:35:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 12:35:36 -0700
commitc5bd63fd520df4ca2f8159eef212289fb8c3ea6c (patch)
treec31a99acaaa532d6875b15dfeaaa8695c11ed976 /tensorflow/python/kernel_tests
parent58845f229be9b5ba2e1e36150bff5ba7a85920d8 (diff)
parente6981fc2225a529427391e98f492eee7bb865988 (diff)
Merge pull request #20476 from yongtang:06052018-bincount-shape
PiperOrigin-RevId: 215947463
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 8a58b3f97e..8177cdd454 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
+ def test_shape_function(self):
+ # size must be scalar.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
+ # size must be positive.
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
+ # if size is a constant then the shape is known.
+ v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
+ self.assertAllEqual(v1.get_shape().as_list(), [5])
+ # if size is a placeholder then the shape is unknown.
+ s = array_ops.placeholder(dtype=dtypes.int32)
+ v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
+ self.assertAllEqual(v2.get_shape().as_list(), [None])
+
if __name__ == "__main__":
googletest.main()