diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 180 |
1 files changed, 72 insertions, 108 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f9aaa3080e..bf08970560 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -30,7 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -43,10 +41,6 @@ namespace { // transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: - using MapAndBatchIteratorFunction = - std::function<void(IteratorContext*, const string&, std::vector<Tensor>, - std::shared_ptr<std::vector<Tensor>>, StatusCallback)>; - explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { @@ -97,66 +91,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - std::vector<int> indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapAndBatchIteratorFunction map_func; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - map_func = [raw_captured_func]( - IteratorContext* ctx, const string& prefix, - std::vector<Tensor> args, - std::shared_ptr<std::vector<Tensor>> out_tensors, - StatusCallback done) { - raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), - std::move(done), prefix); - }; - } else { - std::vector<bool> can_move = ComputeMoveVector(indices); - map_func = [indices, can_move]( - IteratorContext* ctx, const string& prefix, - std::vector<Tensor> args, - std::shared_ptr<std::vector<Tensor>> out_tensors, - StatusCallback done) { - for (size_t i = 0; i < indices.size(); ++i) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } - done(Status::OK()); - }; - } - - *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, - std::move(captured_func), &ctx->eigen_cpu_device(), - std::move(map_func)); + *output = new Dataset(ctx, input, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, func_, + std::move(captured_func), &ctx->eigen_cpu_device()); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, + const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, - const Eigen::ThreadPoolDevice* device, - MapAndBatchIteratorFunction map_func) + const Eigen::ThreadPoolDevice* device) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), + map_fn_(func), captured_func_(std::move(captured_func)), - device_(device), - map_func_(std::move(map_func)) { + device_(device) { input_->Ref(); } @@ -164,9 +123,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return MakeUnique<Iterator>( - Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, - map_func_); + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); } const DataTypeVector& output_dtypes() const override { @@ -185,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -207,7 +165,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(map_fn_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -227,14 +185,12 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params, - MapAndBatchIteratorFunction map_func) + explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), mu_(std::make_shared<mutex>()), cond_var_(std::make_shared<condition_variable>()), num_parallel_calls_(std::make_shared<model::SharedState>( - params.dataset->num_parallel_calls_, mu_, cond_var_)), - map_func_(std::move(map_func)) {} + params.dataset->num_parallel_calls_, mu_, cond_var_)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -341,6 +297,44 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; + void Callback(const std::shared_ptr<IteratorContext>& ctx, + const std::shared_ptr<BatchResult>& result, + const std::shared_ptr<std::vector<Tensor>>& return_values, + int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does not " + "match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + } + void CallCompleted(const std::shared_ptr<BatchResult>& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); @@ -369,48 +363,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - std::shared_ptr<std::vector<Tensor>> return_values = - std::make_shared<std::vector<Tensor>>(); - auto done = [this, ctx, result, return_values, offset](Status status) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does " - "not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor - // batching. - Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - }; - - // Apply the map function on `input_element`, storing the result in - // `return_values`, and invoking `done` when finished. - map_func_(ctx.get(), prefix(), std::move(input_element), - std::move(return_values), std::move(done)); + // Call `captured_func_(input_element)`, using `Callback` to store the + // result in `result`. + (*ctx->runner())(std::bind( + [this, result, offset](std::shared_ptr<IteratorContext> ctx, + std::vector<Tensor> input_element) { + std::shared_ptr<std::vector<Tensor>> return_values( + new std::vector<Tensor>()); + dataset()->captured_func_->RunAsync( + ctx.get(), std::move(input_element), return_values.get(), + [this, ctx, result, return_values, offset](Status status) { + Callback(ctx, result, return_values, offset, status); + }, + prefix()); + }, + ctx, std::move(input_element))); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -437,7 +404,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - auto ctx_copy = std::make_shared<IteratorContext>(*ctx); + std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&Iterator::RunnerThread, this, ctx_copy))); @@ -542,8 +509,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.push_back( - std::make_shared<BatchResult>(dataset()->batch_size_)); + batch_results_.emplace_back( + new BatchResult(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -560,8 +527,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - batch_results_.push_back( - std::make_shared<BatchResult>(dataset()->batch_size_)); + batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); std::shared_ptr<BatchResult> result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -687,8 +653,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr<model::SharedState> num_parallel_calls_; - const MapAndBatchIteratorFunction map_func_; - // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. @@ -707,9 +671,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const NameAttrList map_fn_; const std::unique_ptr<CapturedFunction> captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned - const MapAndBatchIteratorFunction map_func_; }; const int op_version_; |