diff options
Diffstat (limited to 'tensorflow/core/kernels/data/filter_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/filter_dataset_op.cc | 162 |
1 files changed, 98 insertions, 64 deletions
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index be7d182a1f..00884314a9 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,11 +18,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -33,84 +31,67 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: - using FilterIteratorPredicate = - std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>; - explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { + FunctionLibraryRuntime::Handle pred_handle; + OP_REQUIRES_OK(ctx, + ctx->function_library()->Instantiate( + func_.name(), AttrSlice(&func_.attr()), &pred_handle)); + auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { + OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); + }); + + const FunctionBody* pred_body = + ctx->function_library()->GetFunctionBody(pred_handle); + OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, + errors::InvalidArgument( + "predicate function must have a single return value.")); + Node* ret_node = pred_body->ret_nodes[0]; + Node* ret_input_node; + OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); + std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - std::vector<int> indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - OP_REQUIRES(ctx, indices.size() <= 1, - errors::InvalidArgument( - "predicate function has more than one return value.")); - - FilterIteratorPredicate filter_pred; - if (indices.empty()) { - CapturedFunction* raw_captured_func = captured_func.get(); - filter_pred = [raw_captured_func](IteratorContext* ctx, - const std::vector<Tensor>& args, - bool* out_matched) { - std::vector<Tensor> result; - TF_RETURN_IF_ERROR( - raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar<bool>()(); - return Status::OK(); - }; + if (ret_input_node->def().op() == "_Arg") { + int32 index = -1; + OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); + *output = new FilterTensorDataset(ctx, input, func_, + std::move(captured_func), index); } else { - filter_pred = [indices](IteratorContext* ctx, - const std::vector<Tensor>& args, - bool* out_matched) { - const Tensor& predicate = args[indices[0]]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar<bool>()(); - return Status::OK(); - }; + *output = new FilterFunctionDataset(ctx, input, func_, + std::move(captured_func)); } - - *output = new Dataset(ctx, input, func_, std::move(captured_func), - std::move(filter_pred)); } private: - class Dataset : public DatasetBase { + const int graph_def_version_; + + class FilterDatasetBase : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func, - FilterIteratorPredicate filter_pred) + FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - captured_func_(std::move(captured_func)), - filter_pred_(std::move(filter_pred)) { + captured_func_(std::move(captured_func)) { input_->Ref(); } - ~Dataset() override { input_->Unref(); } + ~FilterDatasetBase() override { input_->Unref(); } std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return MakeUnique<Iterator>( - Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, - filter_pred_); + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Filter")})); } const DataTypeVector& output_dtypes() const override { @@ -152,15 +133,17 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + virtual Status EvaluatePredicate(IteratorContext* ctx, + const std::vector<Tensor>& element, + bool* out_matched) const = 0; + private: - class Iterator : public DatasetIterator<Dataset> { + class Iterator : public DatasetIterator<FilterDatasetBase> { public: - explicit Iterator(const Params& params, - FilterIteratorPredicate filter_pred) - : DatasetIterator<Dataset>(params), + explicit Iterator(const Params& params) + : DatasetIterator<FilterDatasetBase>(params), filtered_elements_(0), - dropped_elements_(0), - filter_pred_(std::move(filter_pred)) { + dropped_elements_(0) { std::vector<string> components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -197,7 +180,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR( + dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match. out_tensors->clear(); @@ -267,14 +251,64 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); - const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; + + protected: const std::unique_ptr<CapturedFunction> captured_func_; - const FilterIteratorPredicate filter_pred_; + }; + + class FilterFunctionDataset : public FilterDatasetBase { + public: + using FilterDatasetBase::FilterDatasetBase; + + protected: + Status EvaluatePredicate(IteratorContext* ctx, + const std::vector<Tensor>& element, + bool* out_matched) const override { + // TODO(mrry): Avoid blocking a threadpool thread. We will need to + // stack-rip the iterators and use async kernels. + std::vector<Tensor> result; + TF_RETURN_IF_ERROR( + captured_func_->RunWithBorrowedArgs(ctx, element, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar<bool>()(); + return Status::OK(); + } + }; + + class FilterTensorDataset : public FilterDatasetBase { + public: + FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, + int32 index) + : FilterDatasetBase(ctx, input, func, std::move(captured_func)), + index_(index) {} + + protected: + Status EvaluatePredicate(IteratorContext* ctx, + const std::vector<Tensor>& element, + bool* out_matched) const override { + const Tensor& predicate = element[index_]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar<bool>()(); + return Status::OK(); + } + + private: + const int32 index_; }; private: |