diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 12:35:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 12:35:36 -0700 |
commit | c5bd63fd520df4ca2f8159eef212289fb8c3ea6c (patch) | |
tree | c31a99acaaa532d6875b15dfeaaa8695c11ed976 /tensorflow/core | |
parent | 58845f229be9b5ba2e1e36150bff5ba7a85920d8 (diff) | |
parent | e6981fc2225a529427391e98f492eee7bb865988 (diff) |
Merge pull request #20476 from yongtang:06052018-bincount-shape
PiperOrigin-RevId: 215947463
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops_test.cc | 12 |
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3eff728f03..a9e5e7824d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. + int32 size_val = size_tensor->scalar<int32>()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index be4c3ed2b6..05379a7d69 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow |