diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 160 |
1 files changed, 88 insertions, 72 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 51a7fd23a8..bf08970560 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS +#include <atomic> #include <utility> #include "tensorflow/core/common_runtime/function.h" @@ -26,20 +27,22 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. +// TODO(b/116852688): Make coordination between the performance model and this +// transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -49,14 +52,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 batch_size; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); OP_REQUIRES( @@ -77,7 +72,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { case 2: OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, + num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); break; @@ -92,8 +88,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, func_, @@ -101,7 +97,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, @@ -110,7 +106,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, const Eigen::ThreadPoolDevice* device) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), @@ -147,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -190,21 +186,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + : DatasetIterator<Dataset>(params), + mu_(std::make_shared<mutex>()), + cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + params.dataset->num_parallel_calls_, mu_, cond_var_)) {} ~Iterator() override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; - cond_var_.notify_all(); + cond_var_->notify_all(); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } } Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + mutex_lock l(*mu_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = 1; + AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, + port::NumSchedulableCPUs()); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + } + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); } Status GetNextInternal(IteratorContext* ctx, @@ -212,25 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { bool* end_of_sequence) override { std::shared_ptr<BatchResult> result; { - mutex_lock l(mu_); + mutex_lock l(*mu_); EnsureRunnerThreadStarted(ctx); while (batch_results_.empty() || batch_results_.front()->num_calls > 0) { - cond_var_.wait(l); + RecordStop(ctx); + cond_var_->wait(l); + RecordStart(ctx); } std::swap(result, batch_results_.front()); batch_results_.pop_front(); + cond_var_->notify_all(); } - cond_var_.notify_all(); return ProcessResult(ctx, result, out_tensors, end_of_sequence); } protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -246,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR( reader->ReadScalar(full_name("call_counter"), &call_counter_)); @@ -287,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void Callback(const std::shared_ptr<IteratorContext>& ctx, const std::shared_ptr<BatchResult>& result, const std::shared_ptr<std::vector<Tensor>>& return_values, - int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) { + int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { result->UpdateStatus(status); if (status.ok()) { EnsureOutputAllocated(ctx, result, return_values); @@ -323,18 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } void CallCompleted(const std::shared_ptr<BatchResult>& result) - LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - num_calls_--; - result->num_calls--; - } - cond_var_.notify_all(); + LOCKS_EXCLUDED(*mu_) { + mutex_lock l(*mu_); + num_calls_--; + result->num_calls--; + cond_var_->notify_all(); } void CallFunction(std::shared_ptr<IteratorContext> ctx, const std::shared_ptr<BatchResult>& result, - int64 offset) LOCKS_EXCLUDED(mu_) { + int64 offset) LOCKS_EXCLUDED(*mu_) { // Get the next input element. std::vector<Tensor> input_element; bool end_of_input; @@ -363,7 +374,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ctx.get(), std::move(input_element), return_values.get(), [this, ctx, result, return_values, offset](Status status) { Callback(ctx, result, return_values, offset, status); - }); + }, + prefix()); }, ctx, std::move(input_element))); } @@ -390,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( @@ -420,11 +432,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { result->output_allocated = true; } - int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) / - dataset()->batch_size_; - } - Status ProcessResult(IteratorContext* ctx, const std::shared_ptr<BatchResult>& result, std::vector<Tensor>* out_tensors, @@ -471,28 +478,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; - new_calls.reserve(dataset()->num_parallel_calls_); + RecordStart(ctx.get()); + auto stop_cleanup = + gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); }); + new_calls.reserve(num_parallel_calls_->value); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { + int64 num_parallel_calls = num_parallel_calls_->value; + int64 max_batch_results = + (num_parallel_calls + dataset()->batch_size_ - 1) / + dataset()->batch_size_; + return num_calls_ >= num_parallel_calls || + (batch_results_.size() > max_batch_results || + (batch_results_.size() == max_batch_results && + call_counter_ % dataset()->batch_size_ == 0)); + }; while (true) { { - mutex_lock l(mu_); - while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || - batch_results_.size() > MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ == 0))) { - cond_var_.wait(l); + mutex_lock l(*mu_); + while (!cancelled_ && busy()) { + RecordStop(ctx.get()); + cond_var_->wait(l); + RecordStart(ctx.get()); } if (cancelled_) { return; } - while (num_calls_ < dataset()->num_parallel_calls_ && - (batch_results_.size() < MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ != 0))) { + while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { batch_results_.emplace_back( new BatchResult(dataset()->batch_size_)); @@ -511,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, - size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); std::shared_ptr<BatchResult> result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); @@ -556,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status ReadStatus(IteratorStateReader* reader, const string& prefix, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name(strings::StrCat(prefix, "_code")), &code_int)); @@ -574,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status WriteBatchResult(IteratorStateWriter* writer, size_t index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { std::shared_ptr<BatchResult> result = batch_results_[index]; string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -615,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status WriteStatus(IteratorStateWriter* writer, const string& prefix, - const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")), static_cast<int64>(status.code()))); @@ -629,22 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // Used for coordination between the main thread, the runner thread, and // the callback threads. - mutex mu_; + const std::shared_ptr<mutex> mu_; // Used for coordination between the main thread, the runner thread, and // the callback threads. In particular, the runner thread should only - // schedule new calls when the number of in-flight calls is less than the - // user specified level of parallelism and there are slots available in - // the `batch_results_` buffer. - condition_variable cond_var_; + // schedule new calls when the number of in-flight calls is less than + // `num_parallel_calls_->value` and there are slots available in the + // `batch_results_` buffer. + const std::shared_ptr<condition_variable> cond_var_; + // Identifies the maximum number of parallel calls. + const std::shared_ptr<model::SharedState> num_parallel_calls_; // Counts the number of outstanding calls for this batch. - int64 num_calls_ GUARDED_BY(mu_) = 0; + int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. - int64 call_counter_ GUARDED_BY(mu_) = 0; + int64 call_counter_ GUARDED_BY(*mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; // Buffer for storing the (intermediate) batch results. - std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_) = false; + std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + bool cancelled_ GUARDED_BY(*mu_) = false; }; const DatasetBase* const input_; @@ -659,7 +676,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const Eigen::ThreadPoolDevice* device_; // not owned }; - const int graph_def_version_; const int op_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; @@ -673,5 +689,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), MapAndBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow |