aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/dataset_ops.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-01-24 14:03:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 14:10:01 -0800
commit3ec5b5855eb4dc781f471df2f8c981b1f0951ff2 (patch)
tree907b8fa6f2b5cea12317540808047d8bbd902f50 /tensorflow/core/ops/dataset_ops.cc
parentf7b60fd704ced02f250b2dd0da436cf0ad9c9a8c (diff)
[Eager mode] Add an Eager-optimized version of the IteratorGetNext op.
Since in eager mode the kernel execution happens on the calling thread (which would be blocked anyway), we can invoke the iterator synchronously, and do not need to perform a context switch to and from a background thread. Improves the latency benchmarks in //third_party/tensorflow/contrib/eager/python:datasets_test by approximately 7us to 13us per element. PiperOrigin-RevId: 183137488
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r--tensorflow/core/ops/dataset_ops.cc68
1 files changed, 32 insertions, 36 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index b86816bb54..2cae814eab 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -409,53 +409,49 @@ REGISTER_OP("OneShotIterator")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+namespace {
+
+Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("IteratorGetNext")
.Input("iterator: resource")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- });
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("IteratorGetNextSync")
+ .Input("iterator: resource")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
REGISTER_OP("DatasetToSingleElement")
.Input("dataset: variant")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- });
+ .SetShapeFn(IteratorGetNextShapeFn);
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")