diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 12:35:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 12:35:36 -0700 |
commit | c5bd63fd520df4ca2f8159eef212289fb8c3ea6c (patch) | |
tree | c31a99acaaa532d6875b15dfeaaa8695c11ed976 /tensorflow/python/kernel_tests | |
parent | 58845f229be9b5ba2e1e36150bff5ba7a85920d8 (diff) | |
parent | e6981fc2225a529427391e98f492eee7bb865988 (diff) |
Merge pull request #20476 from yongtang:06052018-bincount-shape
PiperOrigin-RevId: 215947463
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/bincount_op_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py index 8a58b3f97e..8177cdd454 100644 --- a/tensorflow/python/kernel_tests/bincount_op_test.py +++ b/tensorflow/python/kernel_tests/bincount_op_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.InvalidArgumentError): math_ops.bincount([1, 2, 3, -1, 6, 8]).eval() + def test_shape_function(self): + # size must be scalar. + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], []) + # size must be positive. + with self.assertRaisesRegexp(ValueError, "must be non-negative"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, []) + # if size is a constant then the shape is known. + v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, []) + self.assertAllEqual(v1.get_shape().as_list(), [5]) + # if size is a placeholder then the shape is unknown. + s = array_ops.placeholder(dtype=dtypes.int32) + v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, []) + self.assertAllEqual(v2.get_shape().as_list(), [None]) + if __name__ == "__main__": googletest.main() |