Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/data/iterator_ops.cc | 173
1 file changed, 138 insertions(+), 35 deletions(-)
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b476a452a5..da489db7c8 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -662,21 +662,89 @@ class MakeIteratorOp : public OpKernel {
   }
 };
 
+// A simple background worker that executes closures asynchronously and without
+// blocking.
+//
+// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
+// to avoid blocking an executor thread that may be required by the blocking
+// work.
+//
+// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
+// purpose because its current implementation (in Eigen) uses a finite-length
+// queue and will block the caller when full. This can lead to deadlock under
+// heavy load. Since the number of concurrent work items in each user of a
+// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
+// overhead is tolerable.
+class BackgroundWorker {
+ public:
+  BackgroundWorker(Env* env, const string& name) {
+    thread_.reset(env->StartThread({} /* thread_options */, name,
+                                   [this]() { WorkerLoop(); }));
+  }
+
+  ~BackgroundWorker() {
+    {
+      mutex_lock l(mu_);
+      cancelled_ = true;
+    }
+    cond_var_.notify_one();
+    // Block until the background thread has terminated.
+    //
+    // NOTE(mrry): We explicitly free and join the thread here because
+    // `WorkerLoop()` uses other members of this object, and so we must join
+    // the thread before destroying them.
+    thread_.reset();
+  }
+
+  void Schedule(std::function<void()> work_item) {
+    {
+      mutex_lock l(mu_);
+      work_queue_.push_back(std::move(work_item));
+    }
+    cond_var_.notify_one();
+  }
+
+ private:
+  void WorkerLoop() {
+    while (true) {
+      std::function<void()> work_item = nullptr;
+      {
+        mutex_lock l(mu_);
+        while (!cancelled_ && work_queue_.empty()) {
+          cond_var_.wait(l);
+        }
+        if (cancelled_) {
+          return;
+        }
+        DCHECK(!work_queue_.empty());
+        work_item = std::move(work_queue_.front());
+        work_queue_.pop_front();
+      }
+      DCHECK(work_item != nullptr);
+      work_item();
+    }
+  }
+
+  std::unique_ptr<Thread> thread_;
+  mutex mu_;
+  condition_variable cond_var_;
+  bool cancelled_ GUARDED_BY(mu_) = false;
+  std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
+};
+
 class ToSingleElementOp : public AsyncOpKernel {
  public:
   explicit ToSingleElementOp(OpKernelConstruction* ctx)
       : AsyncOpKernel(ctx),
-        thread_pool_(new thread::ThreadPool(
-            ctx->env(), ThreadOptions(),
-            strings::StrCat("to_single_element_op_thread_",
-                            SanitizeThreadSuffix(name())),
-            1 /* num_threads */, false /* low_latency_hint */)) {}
+        background_worker_(ctx->env(),
+                           strings::StrCat("to_single_element_op_thread_",
+                                           SanitizeThreadSuffix(name()))) {}
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     // The call to `iterator->GetNext()` may block and depend on an
     // inter-op thread pool thread, so we issue the call from the
     // owned thread pool.
-    thread_pool_->Schedule([ctx, done]() {
+    background_worker_.Schedule([ctx, done]() {
       DatasetBase* dataset;
       OP_REQUIRES_OK_ASYNC(
           ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
@@ -686,46 +754,60 @@ class ToSingleElementOp : public AsyncOpKernel {
           ctx,
           dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
           done);
+
+      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+      // avoid destruction races.
+      IteratorBase* raw_iterator = iterator.release();
+      auto cleanup = gtl::MakeCleanup([ctx, raw_iterator, done] {
+        delete raw_iterator;
+        done();
+      });
       std::vector<Tensor> components;
       components.reserve(dataset->output_dtypes().size());
-      bool end_of_sequence;
-
-      OP_REQUIRES_OK_ASYNC(
-          ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
-          done);
-      OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
-                        errors::InvalidArgument("Dataset was empty."), done);
+      bool end_of_sequence = false;
+      Status s =
+          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+      if (!s.ok()) {
+        ctx->SetStatus(s);
+        return;
+      }
+      if (end_of_sequence) {
+        ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
+        return;
+      }
       for (int i = 0; i < components.size(); ++i) {
         // TODO(mrry): Check that the shapes match the shape attrs.
         ctx->set_output(i, components[i]);
       }
       components.clear();
-      OP_REQUIRES_OK_ASYNC(
-          ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
-          done);
-      OP_REQUIRES_ASYNC(
-          ctx, end_of_sequence,
-          errors::InvalidArgument("Dataset had more than one element."), done);
-
-      done();
+      Status s2 =
+          raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+      if (!s2.ok()) {
+        ctx->SetStatus(s2);
+        return;
+      }
+      if (!end_of_sequence) {
+        ctx->SetStatus(
+            errors::InvalidArgument("Dataset had more than one element."));
+        return;
+      }
     });
   }
 
  private:
-  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  BackgroundWorker background_worker_;
 };
 
 class OneShotIteratorOp : public AsyncOpKernel {
  public:
   explicit OneShotIteratorOp(OpKernelConstruction* ctx)
       : AsyncOpKernel(ctx),
-        thread_pool_(new thread::ThreadPool(
-            ctx->env(), ThreadOptions(),
+        background_worker_(
+            ctx->env(),
             strings::StrCat("one_shot_iterator_initialization_thread_",
-                            SanitizeThreadSuffix(name())),
-            1 /* num_threads */, false /* low_latency_hint */)),
+                            SanitizeThreadSuffix(name()))),
         graph_def_version_(ctx->graph_def_version()) {
@@ -767,7 +849,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
     if (!initialization_started_) {
       // TODO(mrry): Convert the initialization code to use
       // callbacks instead of wasting a thread.
-      thread_pool_->Schedule([this, ctx, done]() { Init(ctx, done); });
+      background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
       initialization_started_ = true;
     } else {
       done_callbacks_.emplace_back(ctx, std::move(done));
@@ -900,7 +982,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
   DataTypeVector output_dtypes_;
   std::vector<PartialTensorShape> output_shapes_;
 
-  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  BackgroundWorker background_worker_;
 
   mutex mu_;
   ContainerInfo cinfo_ GUARDED_BY(mu_);
@@ -917,11 +999,9 @@ class IteratorGetNextOp : public AsyncOpKernel {
  public:
   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
       : AsyncOpKernel(ctx),
-        thread_pool_(new thread::ThreadPool(
-            ctx->env(), ThreadOptions(),
-            strings::StrCat("iterator_get_next_thread_",
-                            SanitizeThreadSuffix(name())),
-            1 /* num_threads */, false /* low_latency_hint */)) {}
+        background_worker_(ctx->env(),
+                           strings::StrCat("iterator_get_next_thread_",
+                                           SanitizeThreadSuffix(name()))) {}
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     IteratorResource* iterator;
@@ -930,7 +1010,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
     // The call to `iterator->GetNext()` may block and depend on an
     // inter-op thread pool thread, so we issue the call from the
     // owned thread pool.
-    thread_pool_->Schedule(std::bind(
+    background_worker_.Schedule(std::bind(
         [ctx, iterator](DoneCallback done) {
           std::vector<Tensor> components;
           bool end_of_sequence = false;
@@ -967,7 +1047,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
   }
 
  private:
-  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  BackgroundWorker background_worker_;
 };
 
 class IteratorGetNextSyncOp : public OpKernel {
@@ -1135,22 +1215,45 @@ class DeserializeIteratorOp : public OpKernel {
 
 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU),
+                        IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU),
+                        IteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                         MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(
+    Name("MakeIterator").Device(DEVICE_GPU).HostMemory("dataset"),
+    MakeIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
                         AnonymousIteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
+                        AnonymousIteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                         ToSingleElementOp);
 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                         OneShotIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
                         IteratorGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU),
+                        IteratorGetNextOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
                         IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
+                        IteratorGetNextSyncOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                         IteratorToStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("string_handle"),
+                        IteratorToStringHandleOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                         IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2").Device(DEVICE_CPU),
+                        IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("string_handle"),
+                        IteratorFromStringHandleOp);
 REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                         SerializeIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
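
The NOTE(mrry) comment in the diff is the heart of this change: Eigen's thread::ThreadPool backs Schedule() with a finite-length queue, so a caller can block and, under heavy load, deadlock, whereas BackgroundWorker keeps an unbounded std::deque guarded by a mutex and never blocks the scheduling thread. Below is a minimal, standalone sketch of the same pattern written against the C++11 standard library rather than TensorFlow's Env/Thread/mutex wrappers; the class name, the std::promise handshake, and the main() harness are illustrative assumptions, not part of the patch.

// A standalone sketch of the BackgroundWorker pattern above, using only the
// C++11 standard library (an assumption for illustration; the real class
// uses tensorflow::Env::StartThread, tensorflow::mutex, and
// tensorflow::condition_variable).
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <iostream>
#include <mutex>
#include <thread>

class BackgroundWorkerSketch {
 public:
  BackgroundWorkerSketch() : thread_([this] { WorkerLoop(); }) {}

  ~BackgroundWorkerSketch() {
    {
      std::lock_guard<std::mutex> l(mu_);
      cancelled_ = true;
    }
    cond_var_.notify_one();
    // Join before the other members are destroyed, since WorkerLoop() uses
    // them -- the same ordering the NOTE(mrry) comment insists on.
    thread_.join();
  }

  // Never blocks the caller: the deque grows as needed, unlike the bounded
  // queue in the Eigen-backed thread::ThreadPool that motivated this class.
  void Schedule(std::function<void()> work_item) {
    {
      std::lock_guard<std::mutex> l(mu_);
      work_queue_.push_back(std::move(work_item));
    }
    cond_var_.notify_one();
  }

 private:
  void WorkerLoop() {
    while (true) {
      std::function<void()> work_item;
      {
        std::unique_lock<std::mutex> l(mu_);
        cond_var_.wait(l,
                       [this] { return cancelled_ || !work_queue_.empty(); });
        if (cancelled_) return;  // Pending items are dropped, as in the patch.
        work_item = std::move(work_queue_.front());
        work_queue_.pop_front();
      }
      work_item();  // Run outside the lock so Schedule() never waits on us.
    }
  }

  std::mutex mu_;
  std::condition_variable cond_var_;
  bool cancelled_ = false;
  std::deque<std::function<void()>> work_queue_;
  std::thread thread_;  // Declared last: starts only after the members above.
};

int main() {
  BackgroundWorkerSketch worker;
  std::promise<void> finished;
  // Offload a "blocking" closure, as the AsyncOpKernels do with GetNext().
  worker.Schedule([&finished] {
    std::cout << "blocking work ran off the caller's thread\n";
    finished.set_value();
  });
  finished.get_future().wait();  // Keep main alive until the item has run.
}

As in the patched WorkerLoop(), cancellation takes priority over pending work: the destructor only guarantees that no item is left running, not that every queued item executes.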
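The NOTE(jsimsa) hunk also explains why ToSingleElementOp stops using OP_REQUIRES_OK_ASYNC/OP_REQUIRES_ASYNC for the GetNext() calls: those macros invoke done() themselves on failure, which would double-invoke it once the gtl::MakeCleanup guard, which deletes the iterator and then calls done(), fires at scope exit. Switching to ctx->SetStatus(s) plus a plain return leaves the guard as the single place where done() runs, always after the iterator is destroyed. The sketch below shows that scope-guard ordering in isolation; the hand-rolled Cleanup/MakeCleanup and FakeIterator are simplified stand-ins for TensorFlow's gtl::Cleanup (tensorflow/core/lib/gtl/cleanup.h) and IteratorBase, not the library code itself.

// A minimal scope-guard sketch of the destruction-ordering trick in the
// patch (assumed names, standard library only).
#include <iostream>
#include <utility>

template <typename F>
class Cleanup {
 public:
  explicit Cleanup(F f) : f_(std::move(f)) {}
  Cleanup(Cleanup&& other) : f_(std::move(other.f_)), engaged_(other.engaged_) {
    other.engaged_ = false;  // A moved-from guard must not fire.
  }
  Cleanup(const Cleanup&) = delete;
  ~Cleanup() {
    if (engaged_) f_();  // Runs on every exit path from the enclosing scope.
  }

 private:
  F f_;
  bool engaged_ = true;
};

template <typename F>
Cleanup<F> MakeCleanup(F f) {
  return Cleanup<F>(std::move(f));
}

// Stand-in for the kernel's IteratorBase; prints so the ordering is visible.
struct FakeIterator {
  ~FakeIterator() { std::cout << "iterator destroyed\n"; }
};

int main() {
  auto done = [] { std::cout << "done() called\n"; };
  FakeIterator* raw_iterator = new FakeIterator;
  // As in the patched ToSingleElementOp: whichever early `return` is taken
  // below, the iterator is deleted strictly before done() fires.
  auto cleanup = MakeCleanup([raw_iterator, done] {
    delete raw_iterator;
    done();
  });
  // ... GetNext() calls with `ctx->SetStatus(s); return;` on error go here ...
  return 0;
}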