aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/map_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/map_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc62
1 files changed, 51 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index f112e1dc43..6b6ffabf4f 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -28,6 +30,9 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapIteratorFunction = std::function<Status(
+ IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
+
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ return raw_captured_func->Run(ctx, std::move(args), out_tensors);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ return Status::OK();
+ };
+ }
+
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_);
+ output_types_, output_shapes_, std::move(map_func));
}
private:
@@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
+ const std::vector<PartialTensorShape>& output_shapes,
+ MapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes) {
+ output_shapes_(output_shapes),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Map")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ explicit Iterator(const Params& params, MapIteratorFunction map_func)
+ : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- Status s =
- dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = map_func_(ctx, args, out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
+ const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;