diff options
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3eff728f03..a9e5e7824d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. + int32 size_val = size_tensor->scalar<int32>()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); |