aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/math_ops_test.cc12
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow