aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/iterator_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc173
1 files changed, 138 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b476a452a5..da489db7c8 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -662,21 +662,89 @@ class MakeIteratorOp : public OpKernel {
}
};
+// A simple background worker that executes closures asynchronously and without
+// blocking.
+//
+// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
+// to avoid blocking an executor thread that may be required by the blocking
+// work.
+//
+// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
+// purpose because its current implementation (in Eigen) uses a finite-length
+// queue and will block the caller when full. This can lead to deadlock under
+// heavy load. Since the number of concurrent work items in each user of a
+// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
+// overhead is tolerable.
+class BackgroundWorker {
+ public:
+ BackgroundWorker(Env* env, const string& name) {  // Starts a thread named `name` via `env` that runs WorkerLoop().
+ thread_.reset(env->StartThread({} /* thread_options */, name,
+ [this]() { WorkerLoop(); }));  // Captures `this`: the thread must be joined before this object's members die (see destructor).
+ }
+
+ ~BackgroundWorker() {  // Requests cancellation, wakes the worker, then releases the thread object.
+ {
+ mutex_lock l(mu_);
+ cancelled_ = true;  // Makes WorkerLoop() return the next time it checks the flag.
+ }
+ cond_var_.notify_one();  // Wakes the worker if it is currently waiting in WorkerLoop().
+ // Block until the background thread has terminated.
+ //
+ // NOTE(mrry): We explicitly free and join the thread here because
+ // `WorkerLoop()` uses other members of this object, and so we must join
+ // the thread before destroying them.
+ thread_.reset();
+ }
+
+ void Schedule(std::function<void()> work_item) {  // Enqueues `work_item` and wakes the worker; never runs the item on the calling thread.
+ {
+ mutex_lock l(mu_);
+ work_queue_.push_back(std::move(work_item));  // Appended at the back; WorkerLoop() pops from the front, so items run in FIFO order.
+ }
+ cond_var_.notify_one();  // Issued after the locked scope above has closed.
+ }
+
+ private:
+ void WorkerLoop() {  // Body of the thread started in the constructor: runs queued items one at a time until cancelled.
+ while (true) {
+ std::function<void()> work_item = nullptr;
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ && work_queue_.empty()) {  // Wait until there is work to run or cancellation has been requested.
+ cond_var_.wait(l);
+ }
+ if (cancelled_) {  // Checked before popping: items still queued at cancellation are never run.
+ return;
+ }
+ DCHECK(!work_queue_.empty());
+ work_item = std::move(work_queue_.front());
+ work_queue_.pop_front();
+ }
+ DCHECK(work_item != nullptr);
+ work_item();  // Invoked after the `mutex_lock` scope above has closed, so `mu_` is not held by `l` while the item runs.
+ }
+ }
+
+ std::unique_ptr<Thread> thread_;  // Created in the constructor; reset explicitly in the destructor (see NOTE there).
+ mutex mu_;  // Guards `cancelled_` and `work_queue_`.
+ condition_variable cond_var_;  // Notified by Schedule() and the destructor; waited on in WorkerLoop().
+ bool cancelled_ GUARDED_BY(mu_) = false;  // Set to true only by the destructor; never cleared.
+ std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);  // Pending work items, oldest at the front.
+};
+
class ToSingleElementOp : public AsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("to_single_element_op_thread_",
- SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
+ background_worker_(ctx->env(),
+ strings::StrCat("to_single_element_op_thread_",
+ SanitizeThreadSuffix(name()))) {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
- thread_pool_->Schedule([ctx, done]() {
+ background_worker_.Schedule([ctx, done]() {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
@@ -686,46 +754,60 @@ class ToSingleElementOp : public AsyncOpKernel {
ctx,
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([ctx, raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
- bool end_of_sequence;
-
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
- OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
- errors::InvalidArgument("Dataset was empty."), done);
+ bool end_of_sequence = false;
+ Status s =
+ raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ return;
+ }
+ if (end_of_sequence) {
+ ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
+ return;
+ }
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
components.clear();
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
- OP_REQUIRES_ASYNC(
- ctx, end_of_sequence,
- errors::InvalidArgument("Dataset had more than one element."), done);
-
- done();
+ Status s2 =
+ raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ if (!s2.ok()) {
+ ctx->SetStatus(s2);
+ return;
+ }
+ if (!end_of_sequence) {
+ ctx->SetStatus(
+ errors::InvalidArgument("Dataset had more than one element."));
+ return;
+ }
});
}
private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
+ BackgroundWorker background_worker_;
};
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
+ background_worker_(
+ ctx->env(),
strings::StrCat("one_shot_iterator_initialization_thread_",
- SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)),
+ SanitizeThreadSuffix(name()))),
graph_def_version_(ctx->graph_def_version())
{
@@ -767,7 +849,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
if (!initialization_started_) {
// TODO(mrry): Convert the initialization code to use
// callbacks instead of wasting a thread.
- thread_pool_->Schedule([this, ctx, done]() { Init(ctx, done); });
+ background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
initialization_started_ = true;
} else {
done_callbacks_.emplace_back(ctx, std::move(done));
@@ -900,7 +982,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
DataTypeVector output_dtypes_;
std::vector<PartialTensorShape> output_shapes_;
- std::unique_ptr<thread::ThreadPool> thread_pool_;
+ BackgroundWorker background_worker_;
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
@@ -917,11 +999,9 @@ class IteratorGetNextOp : public AsyncOpKernel {
public:
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("iterator_get_next_thread_",
- SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
+ background_worker_(ctx->env(),
+ strings::StrCat("iterator_get_next_thread_",
+ SanitizeThreadSuffix(name()))) {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
@@ -930,7 +1010,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
- thread_pool_->Schedule(std::bind(
+ background_worker_.Schedule(std::bind(
[ctx, iterator](DoneCallback done) {
std::vector<Tensor> components;
bool end_of_sequence = false;
@@ -967,7 +1047,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
}
private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
+ BackgroundWorker background_worker_;
};
class IteratorGetNextSyncOp : public OpKernel {
@@ -1135,22 +1215,45 @@ class DeserializeIteratorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU),
+ IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU),
+ IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(
+ Name("MakeIterator").Device(DEVICE_GPU).HostMemory("dataset"),
+ MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
AnonymousIteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
+ AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
IteratorGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU),
+ IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
+ IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2").Device(DEVICE_CPU),
+ IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),