aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/filter_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/filter_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc162
1 files changed, 98 insertions, 64 deletions
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index be7d182a1f..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,11 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -33,84 +31,67 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
- using FilterIteratorPredicate =
- std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
-
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
+ FunctionLibraryRuntime::Handle pred_handle;
+ OP_REQUIRES_OK(ctx,
+ ctx->function_library()->Instantiate(
+ func_.name(), AttrSlice(&func_.attr()), &pred_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
+ OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
+ });
+
+ const FunctionBody* pred_body =
+ ctx->function_library()->GetFunctionBody(pred_handle);
+ OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
+ errors::InvalidArgument(
+ "predicate function must have a single return value."));
+ Node* ret_node = pred_body->ret_nodes[0];
+ Node* ret_input_node;
+ OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- std::vector<int> indices;
- OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
- OP_REQUIRES(ctx, indices.size() <= 1,
- errors::InvalidArgument(
- "predicate function has more than one return value."));
-
- FilterIteratorPredicate filter_pred;
- if (indices.empty()) {
- CapturedFunction* raw_captured_func = captured_func.get();
- filter_pred = [raw_captured_func](IteratorContext* ctx,
- const std::vector<Tensor>& args,
- bool* out_matched) {
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- };
+ if (ret_input_node->def().op() == "_Arg") {
+ int32 index = -1;
+ OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
+ *output = new FilterTensorDataset(ctx, input, func_,
+ std::move(captured_func), index);
} else {
- filter_pred = [indices](IteratorContext* ctx,
- const std::vector<Tensor>& args,
- bool* out_matched) {
- const Tensor& predicate = args[indices[0]];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- };
+ *output = new FilterFunctionDataset(ctx, input, func_,
+ std::move(captured_func));
}
-
- *output = new Dataset(ctx, input, func_, std::move(captured_func),
- std::move(filter_pred));
}
private:
- class Dataset : public DatasetBase {
+ const int graph_def_version_;
+
+ class FilterDatasetBase : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- FilterIteratorPredicate filter_pred)
+ FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)),
- filter_pred_(std::move(filter_pred)) {
+ captured_func_(std::move(captured_func)) {
input_->Ref();
}
- ~Dataset() override { input_->Unref(); }
+ ~FilterDatasetBase() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return MakeUnique<Iterator>(
- Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
- filter_pred_);
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Filter")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -152,15 +133,17 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ virtual Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const = 0;
+
private:
- class Iterator : public DatasetIterator<Dataset> {
+ class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
- explicit Iterator(const Params& params,
- FilterIteratorPredicate filter_pred)
- : DatasetIterator<Dataset>(params),
+ explicit Iterator(const Params& params)
+ : DatasetIterator<FilterDatasetBase>(params),
filtered_elements_(0),
- dropped_elements_(0),
- filter_pred_(std::move(filter_pred)) {
+ dropped_elements_(0) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -197,7 +180,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(
+ dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -267,14 +251,64 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
- const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
+
+ protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- const FilterIteratorPredicate filter_pred_;
+ };
+
+ class FilterFunctionDataset : public FilterDatasetBase {
+ public:
+ using FilterDatasetBase::FilterDatasetBase;
+
+ protected:
+ Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const override {
+ // TODO(mrry): Avoid blocking a threadpool thread. We will need to
+ // stack-rip the iterators and use async kernels.
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ captured_func_->RunWithBorrowedArgs(ctx, element, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ }
+ };
+
+ class FilterTensorDataset : public FilterDatasetBase {
+ public:
+ FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ int32 index)
+ : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
+ index_(index) {}
+
+ protected:
+ Status EvaluatePredicate(IteratorContext* ctx,
+ const std::vector<Tensor>& element,
+ bool* out_matched) const override {
+ const Tensor& predicate = element[index_];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ }
+
+ private:
+ const int32 index_;
};
private: