diff options
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index e2453b9712..2852c49e19 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -105,8 +105,11 @@ REGISTER_OP("RepeatDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate the - // shape of `count`. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle count_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); + return shape_inference::ScalarShape(c); + }); REGISTER_OP("TakeDataset") .Input("input_dataset: variant") |