From 29f596cf21f0332c1e2ece8798fdd9fefd2ba947 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 14:04:59 +0000 Subject: Improve the shape function of Bincount There was not a lot of restriction in shape function of Bincount and the output shape was unknown. It is actually possible to get a better shape output if `size` input is known. This fix adds enhancement to the shape function of Bincount. Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops.cc | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tensorflow/core') diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 1667c398f4..7d0f29368b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1416,6 +1416,10 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + c->set_output(0, c->UnknownShapeOfRank(1)); return Status::OK(); }); -- cgit v1.2.3 From 740c58b6fa5b6e1c85f688fbda322da0231aa169 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 14:44:44 +0000 Subject: Return `[size]` shape if size is known for Bincount. Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 7d0f29368b..b57385f63b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1420,7 +1420,19 @@ REGISTER_OP("Bincount") // The input `size` must be a scalar. TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - c->set_output(0, c->UnknownShapeOfRank(1)); + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. + int32 size_val = size_tensor->scalar()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); -- cgit v1.2.3 From e6981fc2225a529427391e98f492eee7bb865988 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 11 Aug 2018 18:39:13 +0000 Subject: Add additional test cases for Bincount Shape function, and fix clang-format issue Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops.cc | 3 ++- tensorflow/core/ops/math_ops_test.cc | 12 ++++++++++++ tensorflow/python/kernel_tests/bincount_op_test.py | 19 +++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index b57385f63b..0ba4a9a005 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1430,7 +1430,8 @@ REGISTER_OP("Bincount") // Return `[size]` shape if size is known. int32 size_val = size_tensor->scalar()(); if (size_val < 0) { - return errors::InvalidArgument("size (", size_val, ") must be non-negative"); + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); } c->set_output(0, c->MakeShape({size_val})); return Status::OK(); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 23f1538912..7bf7c476f4 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -558,4 +558,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py index 2767df127e..15d9de56db 100644 --- a/tensorflow/python/kernel_tests/bincount_op_test.py +++ b/tensorflow/python/kernel_tests/bincount_op_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -97,6 +99,23 @@ class BincountTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.InvalidArgumentError): math_ops.bincount([1, 2, 3, -1, 6, 8]).eval() + def test_shape_function(self): + # size must be scalar. + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], []) + # size must be positive. + with self.assertRaisesRegexp( + ValueError, "must be non-negative"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, []) + # if size is a constant then the shape is known. + v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, []) + self.assertAllEqual(v1.get_shape().as_list(), [5]) + # if size is a placeholder then the shape is unknown. + s = array_ops.placeholder(dtype=dtypes.int32) + v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, []) + self.assertAllEqual(v2.get_shape().as_list(), [None]) + if __name__ == "__main__": googletest.main() -- cgit v1.2.3