diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_dataset_op.cc | 56 |
1 files changed, 11 insertions, 45 deletions
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 0abb2eb4f3..f112e1dc43 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,9 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -30,9 +28,6 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: - using MapIteratorFunction = std::function<Status( - IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>; - explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -48,36 +43,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); - std::vector<int> indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapIteratorFunction map_func; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - map_func = [raw_captured_func](IteratorContext* ctx, - std::vector<Tensor> args, - std::vector<Tensor>* out_tensors) { - return raw_captured_func->Run(ctx, std::move(args), out_tensors); - }; - } else { - std::vector<bool> can_move = ComputeMoveVector(indices); - map_func = [indices, can_move](IteratorContext* ctx, - std::vector<Tensor> args, - std::vector<Tensor>* out_tensors) { - std::map<int, int> counts; - for (size_t i = 0; i < indices.size(); ++i) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } - return Status::OK(); - }; - } - *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_, std::move(map_func)); + output_types_, output_shapes_); } private: @@ -87,15 +54,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes, - MapIteratorFunction map_func) + const std::vector<PartialTensorShape>& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes), - map_func_(std::move(map_func)) { + output_shapes_(output_shapes) { input_->Ref(); } @@ -103,8 +68,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return MakeUnique<Iterator>( - Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Map")})); } const DataTypeVector& output_dtypes() const override { @@ -151,8 +116,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params, MapIteratorFunction map_func) - : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {} + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -174,7 +139,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - Status s = map_func_(ctx, args, out_tensors); + // TODO(mrry): Avoid blocking a threadpool thread. We will need to + // stack-rip the iterators and use async kernels. + Status s = + dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -199,7 +167,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr<IteratorBase> input_impl_; - const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -207,7 +174,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr<CapturedFunction> captured_func_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; - const MapIteratorFunction map_func_; }; DataTypeVector output_types_; |