aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 12:35:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 12:35:36 -0700
commitc5bd63fd520df4ca2f8159eef212289fb8c3ea6c (patch)
treec31a99acaaa532d6875b15dfeaaa8695c11ed976 /tensorflow/core
parent58845f229be9b5ba2e1e36150bff5ba7a85920d8 (diff)
parente6981fc2225a529427391e98f492eee7bb865988 (diff)
Merge pull request #20476 from yongtang:06052018-bincount-shape
PiperOrigin-RevId: 215947463
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/math_ops_test.cc12
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow