From ae0bc6f006497cc04a2ee75166d4ec71c7154fd8 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 5 Oct 2018 13:34:01 -0700 Subject: [tf.data] Adding specialization for `MapDataset`, `ParallelMapDataset`, and `MapAndBatchDataset` whose user-provided functions have the property that each output argument takes its value directly from an input argument (e.g. `lambda x, y: y, x`). This specialization can produce the result without having to schedule the function using the executor. PiperOrigin-RevId: 215957592 --- tensorflow/core/kernels/data/filter_dataset_op.cc | 162 +++++++++------------- 1 file changed, 64 insertions(+), 98 deletions(-) (limited to 'tensorflow/core/kernels/data/filter_dataset_op.cc') diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 00884314a9..be7d182a1f 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -31,67 +33,84 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: + using FilterIteratorPredicate = + std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>; + explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - FunctionLibraryRuntime::Handle pred_handle; - OP_REQUIRES_OK(ctx, - ctx->function_library()->Instantiate( - func_.name(), AttrSlice(&func_.attr()), &pred_handle)); - auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { - OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); - }); - - const FunctionBody* pred_body = - ctx->function_library()->GetFunctionBody(pred_handle); - OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, - errors::InvalidArgument( - "predicate function must have a single return value.")); - Node* ret_node = pred_body->ret_nodes[0]; - Node* ret_input_node; - OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); - std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - if (ret_input_node->def().op() == "_Arg") { - int32 index = -1; - OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); - *output = new FilterTensorDataset(ctx, input, func_, - std::move(captured_func), index); + std::vector<int> indices; + OP_REQUIRES_OK(ctx,
ComputeShortCircuitIndices(ctx, func_, &indices)); + OP_REQUIRES(ctx, indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); + + FilterIteratorPredicate filter_pred; + if (indices.empty()) { + CapturedFunction* raw_captured_func = captured_func.get(); + filter_pred = [raw_captured_func](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + std::vector<Tensor> result; + TF_RETURN_IF_ERROR( + raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar<bool>()(); + return Status::OK(); + }; } else { - *output = new FilterFunctionDataset(ctx, input, func_, - std::move(captured_func)); + filter_pred = [indices](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + const Tensor& predicate = args[indices[0]]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar<bool>()(); + return Status::OK(); + }; } + + *output = new Dataset(ctx, input, func_, std::move(captured_func), + std::move(filter_pred)); } private: - const int graph_def_version_; - - class FilterDatasetBase : public DatasetBase { + class Dataset : public DatasetBase { public: - FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func) + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, + FilterIteratorPredicate filter_pred) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + filter_pred_(std::move(filter_pred)) { input_->Ref(); } -
~FilterDatasetBase() override { input_->Unref(); } + ~Dataset() override { input_->Unref(); } std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Filter")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, + filter_pred_); } const DataTypeVector& output_dtypes() const override { @@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - virtual Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const = 0; - private: - class Iterator : public DatasetIterator<FilterDatasetBase> { + class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<FilterDatasetBase>(params), + explicit Iterator(const Params& params, + FilterIteratorPredicate filter_pred) + : DatasetIterator<Dataset>(params), filtered_elements_(0), - dropped_elements_(0) { + dropped_elements_(0), + filter_pred_(std::move(filter_pred)) { std::vector<string> components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR( - dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match.
out_tensors->clear(); @@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); + const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; - - protected: const std::unique_ptr<CapturedFunction> captured_func_; - }; - - class FilterFunctionDataset : public FilterDatasetBase { - public: - using FilterDatasetBase::FilterDatasetBase; - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - std::vector<Tensor> result; - TF_RETURN_IF_ERROR( - captured_func_->RunWithBorrowedArgs(ctx, element, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar<bool>()(); - return Status::OK(); - } - }; - - class FilterTensorDataset : public FilterDatasetBase { - public: - FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func, - int32 index) - : FilterDatasetBase(ctx, input, func, std::move(captured_func)), - index_(index) {} - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - const Tensor& predicate = element[index_]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar<bool>()(); - return Status::OK(); - } - - private: - const int32 index_; + const FilterIteratorPredicate filter_pred_; }; private: -- cgit v1.2.3