diff options
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 15e0ca8af9..9dca5f53ce 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -218,7 +218,17 @@ REGISTER_OP("MapAndBatchDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Use index from the end to retrieve the Input shapes, + // so that to avoid guessing the length of "other_arguments". + // batch_size, num_parallel_batches, and drop_remainder are 0-D scalars. + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); + + return shape_inference::ScalarShape(c); + }); REGISTER_OP("MapAndBatchDatasetV2") .Input("input_dataset: variant") @@ -231,7 +241,17 @@ REGISTER_OP("MapAndBatchDatasetV2") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Use index from the end to retrieve the Input shapes, + // so that to avoid guessing the length of "other_arguments". + // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars. + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); + + return shape_inference::ScalarShape(c); + }); REGISTER_OP("PrefetchDataset") .Input("input_dataset: variant") |