aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/iterator_ops.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-01-26 10:09:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 10:14:02 -0800
commita7d4f82660e1e0bfb8b4bd3a4378e389795b7f9e (patch)
tree7329b1cb7f0a5086a964753d14ae7c72e54330d9 /tensorflow/core/kernels/data/iterator_ops.cc
parentc1338b14149b6313280bea455ec1dec2a336bd31 (diff)
[tf.data] Move slow-path-related code into the slow path in IteratorHandleOp::Compute().
This slightly reduces the amount of work performed when an iterator is accessed (after the first access), and potentially reduces contention if concurrent steps are accessing the same iterator. PiperOrigin-RevId: 183406221
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc95
1 files changed, 43 insertions, 52 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 56044a3d41..ca22f10a85 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -430,13 +430,10 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
-// TODO(mrry): Can we simply use the template kernel here?
class IteratorHandleOp : public OpKernel {
public:
explicit IteratorHandleOp(OpKernelConstruction* ctx)
: OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
- OP_REQUIRES_OK(ctx, ctx->allocate_persistent(DT_STRING, TensorShape({2}),
- &handle_, nullptr));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
@@ -460,56 +457,51 @@ class IteratorHandleOp : public OpKernel {
}
void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- FunctionLibraryRuntime* lib = context->function_library();
- std::unique_ptr<DeviceMgr> device_mgr(nullptr);
- std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
- // If the iterator is shared then we construct a new FLR, and pass that in.
- // NOTE(mrry,rohanj): In this case it is not possible to call remote
- // functions from the iterator. We may add this functionality if there
- // is sufficient demand, but it will require a significant refactoring.
- if (!name_.empty()) {
- lib = CreateFLR(context, &device_mgr, &flib_def, &pflr);
- }
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib = context->function_library();
+ std::unique_ptr<DeviceMgr> device_mgr(nullptr);
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ // If the iterator is shared then we construct a new FLR, and pass that
+ // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
+ // functions from the iterator. We may add this functionality if there
+ // is sufficient demand, but it will require a significant refactoring.
+ if (!name_.empty()) {
+ lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
+ }
- if (resource_ == nullptr) {
- ResourceMgr* mgr = context->resource_manager();
- OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ IteratorResource* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<IteratorResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [lib, &device_mgr, &flib_def, &pflr,
+ this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new IteratorResource(
+ output_dtypes_, output_shapes_, graph_def_version_,
+ std::move(device_mgr), std::move(flib_def),
+ std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
- IteratorResource* resource;
- OP_REQUIRES_OK(
- context,
- mgr->LookupOrCreate<IteratorResource>(
- cinfo_.container(), cinfo_.name(), &resource,
- [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- *ret = new IteratorResource(
- output_dtypes_, output_shapes_, graph_def_version_,
- std::move(device_mgr), std::move(flib_def),
- std::move(pflr), lib);
- return Status::OK();
- }));
-
- Status s = VerifyResource(resource);
- if (TF_PREDICT_FALSE(!s.ok())) {
- resource->Unref();
- context->SetStatus(s);
- return;
+ resource_ = resource;
}
-
- auto h = handle_.AccessTensor(context)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
- resource_ = resource;
- }
- if (context->expected_output_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<IteratorResource>()));
- } else {
- context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
}
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<IteratorResource>()));
}
private:
@@ -526,7 +518,7 @@ class IteratorHandleOp : public OpKernel {
return Status::OK();
}
- FunctionLibraryRuntime* CreateFLR(
+ FunctionLibraryRuntime* CreatePrivateFLR(
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
@@ -546,9 +538,8 @@ class IteratorHandleOp : public OpKernel {
}
mutex mu_;
- ContainerInfo cinfo_ GUARDED_BY(mu_);
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
- PersistentTensor handle_ GUARDED_BY(mu_);
DataTypeVector output_dtypes_;
std::vector<PartialTensorShape> output_shapes_;
const int graph_def_version_;