aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/filter_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/filter_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc162
1 files changed, 64 insertions, 98 deletions
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 00884314a9..be7d182a1f 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -31,67 +33,84 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
+ using FilterIteratorPredicate =
+ std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
+
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- FunctionLibraryRuntime::Handle pred_handle;
- OP_REQUIRES_OK(ctx,
- ctx->function_library()->Instantiate(
- func_.name(), AttrSlice(&func_.attr()), &pred_handle));
- auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
- OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
- });
-
- const FunctionBody* pred_body =
- ctx->function_library()->GetFunctionBody(pred_handle);
- OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
- errors::InvalidArgument(
- "predicate function must have a single return value."));
- Node* ret_node = pred_body->ret_nodes[0];
- Node* ret_input_node;
- OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
-
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- if (ret_input_node->def().op() == "_Arg") {
- int32 index = -1;
- OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
- *output = new FilterTensorDataset(ctx, input, func_,
- std::move(captured_func), index);
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+ OP_REQUIRES(ctx, indices.size() <= 1,
+ errors::InvalidArgument(
+ "predicate function has more than one return value."));
+
+ FilterIteratorPredicate filter_pred;
+ if (indices.empty()) {
+ CapturedFunction* raw_captured_func = captured_func.get();
+ filter_pred = [raw_captured_func](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ };
} else {
- *output = new FilterFunctionDataset(ctx, input, func_,
- std::move(captured_func));
+ filter_pred = [indices](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ const Tensor& predicate = args[indices[0]];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ };
}
+
+ *output = new Dataset(ctx, input, func_, std::move(captured_func),
+ std::move(filter_pred));
}
private:
- const int graph_def_version_;
-
- class FilterDatasetBase : public DatasetBase {
+ class Dataset : public DatasetBase {
public:
- FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ FilterIteratorPredicate filter_pred)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ filter_pred_(std::move(filter_pred)) {
input_->Ref();
}
- ~FilterDatasetBase() override { input_->Unref(); }
+ ~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Filter")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
+ filter_pred_);
}
const DataTypeVector& output_dtypes() const override {
@@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- virtual Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const = 0;
-
private:
- class Iterator : public DatasetIterator<FilterDatasetBase> {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
+ explicit Iterator(const Params& params,
+ FilterIteratorPredicate filter_pred)
+ : DatasetIterator<Dataset>(params),
filtered_elements_(0),
- dropped_elements_(0) {
+ dropped_elements_(0),
+ filter_pred_(std::move(filter_pred)) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(
- dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
+ const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
-
- protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- };
-
- class FilterFunctionDataset : public FilterDatasetBase {
- public:
- using FilterDatasetBase::FilterDatasetBase;
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- captured_func_->RunWithBorrowedArgs(ctx, element, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- }
- };
-
- class FilterTensorDataset : public FilterDatasetBase {
- public:
- FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- int32 index)
- : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
- index_(index) {}
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- const Tensor& predicate = element[index_];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- }
-
- private:
- const int32 index_;
+ const FilterIteratorPredicate filter_pred_;
};
private: