aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-08 15:27:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 15:36:28 -0700
commiteb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a (patch)
tree7924e7e57d5ccdf9f1ff2a74a79ac82c6e4d5a49
parenta991acba07ce6c5903ee84e4a72d3d59e22b77fc (diff)
Automated rollback of commit 13b47e6c4f9d7b295948b1057139bf676e394b6f
PiperOrigin-RevId: 216260575
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc9
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc27
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc9
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/writer_ops.cc12
7 files changed, 37 insertions, 44 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 8acd6cc724..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
@@ -25,11 +27,13 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 0fb721cd7c..f45a239793 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -445,10 +445,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
- runner_thread_ =
- MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
- runner_thread_->Schedule(
- std::bind(&Iterator::RunnerThread, this, ctx_copy));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&Iterator::RunnerThread, this, ctx_copy)));
}
}
@@ -704,7 +703,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
- std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 859df57962..9aa505f4f1 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -127,10 +126,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- optimize_thread_ =
- MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread");
- optimize_thread_->Schedule(
- [this, new_ctx]() { OptimizeThread(new_ctx); });
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, new_ctx]() { OptimizeThread(new_ctx); }));
}
return Status::OK();
}
@@ -169,7 +167,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex mu_;
condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
- std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 9c836b836e..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -482,10 +481,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(
- MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
- worker_threads_.back()->Schedule(
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
+ worker_threads_.emplace_back(ctx->env()->StartThread(
+ {}, "worker_thread",
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -582,10 +580,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(
- MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
- worker_threads_.back()->Schedule(
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
+ worker_threads_.emplace_back(ctx->env()->StartThread(
+ {}, "worker_thread",
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -1050,8 +1047,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
- std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
- GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@@ -1393,10 +1389,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- runner_thread_ =
- MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
- runner_thread_->Schedule(
- [this, new_ctx]() { RunnerThread(new_ctx); });
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
}
}
@@ -1650,7 +1645,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index e69274e4f2..ebf41925c9 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -181,10 +181,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
- runner_thread_ =
- MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
- runner_thread_->Schedule(
- std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
}
}
@@ -332,7 +331,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
- std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index e9c38eb8a0..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -257,11 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_ =
- MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread");
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- prefetch_thread_->Schedule(
- [this, new_ctx]() { PrefetchThread(new_ctx); });
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -365,7 +363,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
};
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 7bb2077b62..3f76695bb1 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public:
explicit ToTFRecordOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- background_worker_(
- ctx->env(),
- strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) {
- }
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
- background_worker_.Schedule([this, ctx, done]() {
+ thread_pool_->Schedule([this, ctx, done]() {
string filename;
OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done);
@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
}
private:
- BackgroundWorker background_worker_;
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
};
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),