diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 180 |
1 files changed, 108 insertions, 72 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index bf08970560..f9aaa3080e 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -41,6 +43,10 @@ namespace { // transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: + using MapAndBatchIteratorFunction = + std::function<void(IteratorContext*, const string&, std::vector<Tensor>, + std::shared_ptr<std::vector<Tensor>>, StatusCallback)>; + explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { @@ -91,31 +97,66 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - *output = new Dataset(ctx, input, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, func_, - std::move(captured_func), &ctx->eigen_cpu_device()); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapAndBatchIteratorFunction map_func; + if (indices.empty()) { + CapturedFunction* raw_captured_func = captured_func.get(); + map_func = [raw_captured_func]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), + std::move(done), prefix); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + for (size_t i = 0; i < indices.size(); ++i) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } + done(Status::OK()); + }; + } + + *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, + std::move(captured_func), &ctx->eigen_cpu_device(), + std::move(map_func)); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, - const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, - const Eigen::ThreadPoolDevice* device) + const Eigen::ThreadPoolDevice* device, + MapAndBatchIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), + func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), - map_fn_(func), captured_func_(std::move(captured_func)), - device_(device) { + device_(device), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -123,8 +164,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, + map_func_); } const DataTypeVector& output_dtypes() const override { @@ -143,7 +185,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -165,7 +207,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(map_fn_, &f); + b->BuildAttrValue(func_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -185,12 +227,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) + explicit Iterator(const Params& params, + MapAndBatchIteratorFunction map_func) : DatasetIterator<Dataset>(params), mu_(std::make_shared<mutex>()), cond_var_(std::make_shared<condition_variable>()), num_parallel_calls_(std::make_shared<model::SharedState>( - params.dataset->num_parallel_calls_, mu_, cond_var_)) {} + params.dataset->num_parallel_calls_, mu_, cond_var_)), + map_func_(std::move(map_func)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -297,44 +341,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; - void Callback(const std::shared_ptr<IteratorContext>& ctx, - const std::shared_ptr<BatchResult>& result, - const std::shared_ptr<std::vector<Tensor>>& return_values, - int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does not " - "match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor batching. - Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - } - void CallCompleted(const std::shared_ptr<BatchResult>& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); @@ -363,21 +369,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - // Call `captured_func_(input_element)`, using `Callback` to store the - // result in `result`. - (*ctx->runner())(std::bind( - [this, result, offset](std::shared_ptr<IteratorContext> ctx, - std::vector<Tensor> input_element) { - std::shared_ptr<std::vector<Tensor>> return_values( - new std::vector<Tensor>()); - dataset()->captured_func_->RunAsync( - ctx.get(), std::move(input_element), return_values.get(), - [this, ctx, result, return_values, offset](Status status) { - Callback(ctx, result, return_values, offset, status); - }, - prefix()); - }, - ctx, std::move(input_element))); + std::shared_ptr<std::vector<Tensor>> return_values = + std::make_shared<std::vector<Tensor>>(); + auto done = [this, ctx, result, return_values, offset](Status status) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor + // batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + }; + + // Apply the map function on `input_element`, storing the result in + // `return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + std::move(return_values), std::move(done)); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -404,7 +437,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&Iterator::RunnerThread, this, ctx_copy))); @@ -509,8 +542,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.emplace_back( - new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -527,7 +560,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); std::shared_ptr<BatchResult> result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -653,6 +687,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr<model::SharedState> num_parallel_calls_; + const MapAndBatchIteratorFunction map_func_; + // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. @@ -671,9 +707,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; - const NameAttrList map_fn_; const std::unique_ptr<CapturedFunction> captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned + const MapAndBatchIteratorFunction map_func_; }; const int op_version_; |