diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_dataset_op.cc | 62 |
1 files changed, 51 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index f112e1dc43..6b6ffabf4f 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -28,6 +30,9 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: + using MapIteratorFunction = std::function<Status( + IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>; + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + return raw_captured_func->Run(ctx, std::move(args), out_tensors); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + return Status::OK(); + }; + } + *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_); + output_types_, output_shapes_, std::move(map_func)); } private: @@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) + const std::vector<PartialTensorShape>& output_shapes, + MapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes) { + output_shapes_(output_shapes), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Map")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); } const DataTypeVector& output_dtypes() const override { @@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + explicit Iterator(const Params& params, MapIteratorFunction map_func) + : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - Status s = - dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); + Status s = map_func_(ctx, args, out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr<IteratorBase> input_impl_; + const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr<CapturedFunction> captured_func_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const MapIteratorFunction map_func_; }; DataTypeVector output_types_; |