path: root/tensorflow/core/ops/dataset_ops.cc
author Derek Murray <mrry@google.com> 2018-04-18 22:59:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-18 23:01:53 -0700
commit2a6c5998a239f41926ca295ac20bb595862fd5ff (patch)
tree2291cf24ea993f5efba757eaef0002b2f7971919 /tensorflow/core/ops/dataset_ops.cc
parentee1676d4dbded64e192aecfa693ab605e24c9929 (diff)
[tf.data] Add native implementation for `tf.contrib.data.unbatch()`.
The implementation has two main improvements:

1. Avoid a relatively expensive (~15us) function invocation for each incoming batch.
2. Use std::move() where possible to avoid copying strings/variants into the unbatched elements.

PiperOrigin-RevId: 193467856
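The second improvement can be illustrated with a minimal standalone sketch. The `UnbatchStrings` helper below is hypothetical (it is not the actual kernel code, which operates on `Tensor` components); it only demonstrates the pattern of moving elements out of a consumed batch buffer instead of copying them, which matters for variable-length payloads like strings and variants:

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: split a batch of strings into individual elements.
// Taking the batch by rvalue reference signals that it is consumed, so each
// element can be std::move()'d out rather than deep-copied.
std::vector<std::string> UnbatchStrings(std::vector<std::string>&& batch) {
  std::vector<std::string> elements;
  elements.reserve(batch.size());
  for (std::string& component : batch) {
    // Move, don't copy: steals the string's heap buffer in O(1).
    elements.push_back(std::move(component));
  }
  return elements;
}
```

The same idea applies per-component in the dataset iterator: once the upstream batch will not be read again, its storage can be transferred directly into the unbatched output elements.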
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc | 7
1 file changed, 7 insertions, 0 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 57f871af32..8be569b315 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -83,6 +83,13 @@ REGISTER_OP("GeneratorDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("UnbatchDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("ZipDataset")
.Input("input_datasets: N * variant")
.Output("handle: variant")