aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/experimental_dataset_ops.cc
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-10-09 11:54:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:58:43 -0700
commit072fcb995a3fd658ee2461b59b159498c710513d (patch)
treef3def3d3ac6e270ad32e428889a79d662c8bc9cf /tensorflow/core/ops/experimental_dataset_ops.cc
parent12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff)
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
Diffstat (limited to 'tensorflow/core/ops/experimental_dataset_ops.cc')
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index f6bd5dce26..bbbecc50f8 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -138,6 +138,32 @@ REGISTER_OP("ExperimentalAssertNextDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("ExperimentalNumaMapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_calls: int64")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Use index from the end to retrieve the Input shapes,
+ // so that to avoid guessing the length of "other_arguments".
+ // batch_size, num_parallel_batches, and drop_remainder are 0-D scalars.
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")