diff options
author | 2018-01-24 14:03:09 -0800 | |
---|---|---|
committer | 2018-01-24 14:10:01 -0800 | |
commit | 3ec5b5855eb4dc781f471df2f8c981b1f0951ff2 (patch) | |
tree | 907b8fa6f2b5cea12317540808047d8bbd902f50 /tensorflow/core/ops/dataset_ops.cc | |
parent | f7b60fd704ced02f250b2dd0da436cf0ad9c9a8c (diff) |
[Eager mode] Add an Eager-optimized version of the IteratorGetNext op.
Since in eager mode the kernel execution happens on the calling thread
(which would be blocked anyway), we can invoke the iterator
synchronously, and do not need to perform a context switch to and from
a background thread.
Improves the latency benchmarks in
//third_party/tensorflow/contrib/eager/python:datasets_test by
approximately 7us to 13us per element.
PiperOrigin-RevId: 183137488
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 68 |
1 files changed, 32 insertions, 36 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index b86816bb54..2cae814eab 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -409,53 +409,49 @@ REGISTER_OP("OneShotIterator") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); +namespace { + +Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); +} + +} // namespace + REGISTER_OP("IteratorGetNext") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector<PartialTensorShape> output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast<int>(i), output_shape_handle); - } - return Status::OK(); - }); + .SetShapeFn(IteratorGetNextShapeFn); + +REGISTER_OP("IteratorGetNextSync") + .Input("iterator: resource") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(IteratorGetNextShapeFn); REGISTER_OP("DatasetToSingleElement") .Input("dataset: variant") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector<PartialTensorShape> output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast<int>(i), output_shape_handle); - } - return Status::OK(); - }); + .SetShapeFn(IteratorGetNextShapeFn); REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") |