aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/map_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/map_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc56
1 files changed, 11 insertions, 45 deletions
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 0abb2eb4f3..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,9 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -30,9 +28,6 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- using MapIteratorFunction = std::function<Status(
- IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
-
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -48,36 +43,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
-
- MapIteratorFunction map_func;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- map_func = [raw_captured_func](IteratorContext* ctx,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors) {
- return raw_captured_func->Run(ctx, std::move(args), out_tensors);
- };
- } else {
- std::vector<bool> can_move = ComputeMoveVector(indices);
- map_func = [indices, can_move](IteratorContext* ctx,
- std::vector<Tensor> args,
- std::vector<Tensor>* out_tensors) {
- std::map<int, int> counts;
- for (size_t i = 0; i < indices.size(); ++i) {
- if (can_move[i]) {
- out_tensors->push_back(std::move(args[indices[i]]));
- } else {
- out_tensors->push_back(args[indices[i]]);
- }
- }
- return Status::OK();
- };
- }
-
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_, std::move(map_func));
+ output_types_, output_shapes_);
}
private:
@@ -87,15 +54,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes,
- MapIteratorFunction map_func)
+ const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes),
- map_func_(std::move(map_func)) {
+ output_shapes_(output_shapes) {
input_->Ref();
}
@@ -103,8 +68,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Map")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -151,8 +116,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params, MapIteratorFunction map_func)
- : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -174,7 +139,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- Status s = map_func_(ctx, args, out_tensors);
+ // TODO(mrry): Avoid blocking a threadpool thread. We will need to
+ // stack-rip the iterators and use async kernels.
+ Status s =
+ dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -199,7 +167,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
- const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -207,7 +174,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;