aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/dataset_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r--tensorflow/core/ops/dataset_ops.cc7
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index e2453b9712..2852c49e19 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -105,8 +105,11 @@ REGISTER_OP("RepeatDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate the
- // shape of `count`.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle count_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("TakeDataset")
.Input("input_dataset: variant")