author Akshay Modi <nareshmodi@google.com> 2018-06-18 21:00:34 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-18 21:04:00 -0700
commit 6070ae0e148f50dbc8f36e1654f0a3f53b8b067e (patch)
tree 165e4c050220180a76512e304b70eee0cd02a2db /tensorflow/core/ops/dataset_ops.cc
parent 60b78d6152e6f8d985f3086930ff986c140c36bf (diff)
Merge changes from github.
PiperOrigin-RevId: 201110240
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc | 24
1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 15e0ca8af9..9dca5f53ce 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -218,7 +218,17 @@ REGISTER_OP("MapAndBatchDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Index from the end to retrieve the input shapes, so as to avoid
+ // guessing the length of "other_arguments".
+ // batch_size, num_parallel_batches, and drop_remainder are 0-D scalars.
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("MapAndBatchDatasetV2")
.Input("input_dataset: variant")
@@ -231,7 +241,17 @@ REGISTER_OP("MapAndBatchDatasetV2")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Index from the end to retrieve the input shapes, so as to avoid
+ // guessing the length of "other_arguments".
+ // batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("PrefetchDataset")
.Input("input_dataset: variant")