path: root/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 160
1 file changed, 88 insertions(+), 72 deletions(-)
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 51a7fd23a8..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -26,20 +27,22 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +52,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +72,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
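The widened predicate accepts the kAutoTune sentinel alongside positive values. A minimal standalone sketch of that validation, assuming kAutoTune is a negative sentinel (model::kAutoTune in TensorFlow; the -1 below is illustrative, not guaranteed):

#include <cstdint>
#include <stdexcept>

constexpr int64_t kAutoTune = -1;  // assumed stand-in for model::kAutoTune

void ValidateNumParallelCalls(int64_t num_parallel_calls) {
  // Accept either an explicit positive parallelism or the autotune
  // sentinel, mirroring the OP_REQUIRES condition above.
  if (num_parallel_calls <= 0 && num_parallel_calls != kAutoTune) {
    throw std::invalid_argument(
        "num_parallel_calls must be greater than zero");
  }
}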
@@ -92,8 +88,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -101,7 +97,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
@@ -110,7 +106,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const Eigen::ThreadPoolDevice* device)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
@@ -147,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -190,21 +186,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ mutex_lock l(*mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
+ }
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
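Initialize() now registers the batch size as a constant and, when autotuning is requested, hands the parallelism knob to the performance model as a tunable parameter bounded by the number of schedulable CPUs. A rough sketch of the shared state this relies on, assuming the struct mirrors the fields used in the diff (the real definition lives in the model namespace):

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>

// Assumed shape of model::SharedState: a value the tuner may rewrite,
// plus the mutex/cond_var that guard it and wake any waiting thread.
struct SharedState {
  int64_t value;
  std::shared_ptr<std::mutex> mu;
  std::shared_ptr<std::condition_variable> cond_var;
};

// What a tuner conceptually does when it raises the parallelism knob:
// update the value under the shared lock, then wake the runner thread so
// it re-evaluates its busy() predicate against the new limit.
void SetParallelism(SharedState* state, int64_t new_value) {
  {
    std::lock_guard<std::mutex> l(*state->mu);
    state->value = new_value;
  }
  state->cond_var->notify_all();
}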
@@ -212,25 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
- cond_var_.wait(l);
+ RecordStop(ctx);
+ cond_var_->wait(l);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_->notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
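Every wait on cond_var_ is now bracketed by RecordStop()/RecordStart() so that time spent blocked is not attributed to the iterator as useful work. The pattern in isolation, with hypothetical no-op stand-ins for the IteratorContext-based helpers:

#include <condition_variable>
#include <mutex>

void RecordStop() { /* assumed: mark this thread inactive in the model */ }
void RecordStart() { /* assumed: mark this thread active again */ }

// Stop the model's clock before blocking and restart it once woken; the
// predicate is re-checked on every wakeup, as in the loops above.
template <typename Predicate>
void WaitUntil(std::condition_variable& cv, std::unique_lock<std::mutex>& l,
               Predicate done) {
  while (!done()) {
    RecordStop();
    cv.wait(l);
    RecordStart();
  }
}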
@@ -246,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -287,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void Callback(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result,
const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -323,18 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
- cond_var_.notify_all();
+ LOCKS_EXCLUDED(*mu_) {
+ mutex_lock l(*mu_);
+ num_calls_--;
+ result->num_calls--;
+ cond_var_->notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) LOCKS_EXCLUDED(mu_) {
+ int64 offset) LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -363,7 +374,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -390,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -420,11 +432,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -471,28 +478,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
+ RecordStart(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
- mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
- cond_var_.wait(l);
+ mutex_lock l(*mu_);
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
+ cond_var_->wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
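Folding MaxBatchResults() into the busy() lambda matters because the lambda re-reads num_parallel_calls_->value on every evaluation; the tuner may have changed it while the runner slept. The slot arithmetic itself is plain ceiling division, sketched here with illustrative numbers:

#include <cassert>
#include <cstdint>

int64_t MaxBatchResults(int64_t num_parallel_calls, int64_t batch_size) {
  // Ceiling division: enough in-flight batch slots for all parallel calls.
  return (num_parallel_calls + batch_size - 1) / batch_size;
}

int main() {
  // 100 parallel calls filling batches of 32 can touch at most
  // ceil(100 / 32) = 4 partially filled batches at once.
  assert(MaxBatchResults(100, 32) == 4);
  assert(MaxBatchResults(64, 32) == 2);
  return 0;
}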
@@ -511,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
- size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
@@ -556,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -574,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -615,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
@@ -629,22 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the callback threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
- // schedule new calls when the number of in-flight calls is less than the
- // user specified level of parallelism and there are slots available in
- // the `batch_results_` buffer.
- condition_variable cond_var_;
+ // schedule new calls when the number of in-flight calls is less than
+ // `num_parallel_calls_->value` and there are slots available in the
+ // `batch_results_` buffer.
+ const std::shared_ptr<condition_variable> cond_var_;
+ // Identifies the maximum number of parallel calls.
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
- int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 call_counter_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
- std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
@@ -659,7 +676,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
@@ -673,5 +689,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
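For reference, the consumer path in GetNextInternal reduces to a bounded producer/consumer handoff; a condensed sketch with simplified stand-in types, not the kernel's actual classes:

#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>

struct BatchResult {
  int num_calls = 0;  // outstanding map calls still writing into this batch
};

struct BatchQueue {
  std::mutex mu;
  std::condition_variable cond_var;
  std::deque<std::shared_ptr<BatchResult>> batch_results;

  std::shared_ptr<BatchResult> PopCompletedBatch() {
    std::unique_lock<std::mutex> l(mu);
    // Block until the front batch exists and has no calls in flight.
    cond_var.wait(l, [this] {
      return !batch_results.empty() && batch_results.front()->num_calls == 0;
    });
    std::shared_ptr<BatchResult> result = std::move(batch_results.front());
    batch_results.pop_front();
    // Notify under the lock, as the diff now does: a buffer slot just
    // freed, so the runner thread may schedule another batch.
    cond_var.notify_all();
    return result;
  }
};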