aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-27 16:05:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 16:09:18 -0700
commitece50dd9992ac17e3094c7f6d1914febd7a036b5 (patch)
tree18de739f4a7e33abbc9631b46b3992ac53ff446b /tensorflow/core/kernels
parentb8c86c3bbd8271ed968087f24e7fb704103bc733 (diff)
[tf.data] Introducing tf.data.Dataset.reduce(), which reduces the elements of a (finite) dataset to a single element.
PiperOrigin-RevId: 214852364
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc111
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc4
4 files changed, 114 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index d6ee42a7c6..e7244ee208 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -30,8 +30,7 @@ namespace {
class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
@@ -421,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 8b417bb1c2..14aefe5d54 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -31,8 +31,7 @@ namespace {
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
@@ -507,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index c0bc507ec0..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -659,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+// Asynchronous kernel that folds the elements of an input dataset into a
+// single state by repeatedly applying the function stored in the `f` attr.
+//
+// `ComputeAsync` reads the dataset variant from input 0, the starting state
+// from the `initial_state` input list, and the function's captured inputs
+// from `other_arguments`. The components of the final state become the
+// kernel outputs after being checked against the `output_types` and
+// `output_shapes` attrs.
+class ReduceDatasetOp : public AsyncOpKernel {
+ public:
+  // Reads the `f`, `output_types`, `output_shapes` and
+  // `use_inter_op_parallelism` attrs and creates the background worker on
+  // which the reduction is executed.
+  explicit ReduceDatasetOp(OpKernelConstruction* ctx)
+      : AsyncOpKernel(ctx),
+        background_worker_(
+            ctx->env(),
+            strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+                                     &use_inter_op_parallelism_));
+  }
+
+  // Schedules the reduction on `background_worker_`. `done` is invoked exactly
+  // once on every path: directly by the `*_ASYNC` macros while setting up, and
+  // by the `cleanup` object once the iterator has been created.
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    // The call to `iterator->GetNext()` may block and depend on an
+    // inter-op thread pool thread, so we issue the call from the
+    // background worker owned by this kernel.
+    background_worker_.Schedule([this, ctx, done]() {
+      DatasetBase* dataset = nullptr;
+      OP_REQUIRES_OK_ASYNC(
+          ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
+      OpInputList inputs;
+      OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
+                           done);
+      std::vector<Tensor> state(inputs.begin(), inputs.end());
+
+      std::unique_ptr<CapturedFunction> captured_func;
+      OP_REQUIRES_OK_ASYNC(
+          ctx,
+          CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
+                                   use_inter_op_parallelism_, &captured_func),
+          done);
+
+      IteratorContext iter_ctx(ctx);
+      OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done);
+
+      std::unique_ptr<IteratorBase> iterator;
+      OP_REQUIRES_OK_ASYNC(
+          ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
+          done);
+
+      // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+      // avoid destruction races.
+      //
+      // From this point on `cleanup` is the only caller of `done()`. Every
+      // exit path below must record its error on `ctx` and simply return;
+      // using an `*_ASYNC` macro here would invoke `done()` a second time.
+      IteratorBase* raw_iterator = iterator.release();
+      auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
+        delete raw_iterator;
+        done();
+      });
+
+      // Iterate through the input dataset.
+      Status status;
+      while (true) {
+        std::vector<Tensor> next_input_element;
+        bool end_of_input = false;
+        status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
+                                       &end_of_input);
+        if (!status.ok() || end_of_input) {
+          break;
+        }
+
+        // Run the reduce function on (state..., element...) to obtain the
+        // next state.
+        std::vector<Tensor> args;
+        args.reserve(state.size() + next_input_element.size());
+        args.insert(args.end(), state.begin(), state.end());
+        args.insert(args.end(), next_input_element.begin(),
+                    next_input_element.end());
+
+        std::vector<Tensor> reduce_func_output;
+        status =
+            captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output);
+        if (!status.ok()) {
+          break;
+        }
+        std::swap(reduce_func_output, state);
+      }
+
+      if (!status.ok()) {
+        ctx->SetStatus(status);
+        return;
+      }
+
+      // The loop below indexes `output_types_` and `output_shapes_` by state
+      // component, so reject a component-count mismatch up front instead of
+      // reading out of bounds.
+      OP_REQUIRES(ctx,
+                  state.size() == output_types_.size() &&
+                      state.size() == output_shapes_.size(),
+                  errors::InvalidArgument(
+                      "The result has ", state.size(), " components, but ",
+                      output_types_.size(), " output types and ",
+                      output_shapes_.size(),
+                      " output shapes were expected."));
+
+      // Synchronous `OP_REQUIRES` (set status + return) is used deliberately:
+      // `cleanup` calls `done()` when the lambda returns.
+      const int num_components = static_cast<int>(state.size());
+      for (int i = 0; i < num_components; ++i) {
+        OP_REQUIRES(
+            ctx, state[i].dtype() == output_types_[i],
+            errors::InvalidArgument(
+                "The result does not match the expected type for component ", i,
+                ". Expected: ", DataTypeString(output_types_[i]),
+                ". Actual: ", DataTypeString(state[i].dtype()), "."));
+        OP_REQUIRES(
+            ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
+            errors::InvalidArgument(
+                "The result does not match the expected shape for component ",
+                i, ". Expected: ", output_shapes_[i].DebugString(),
+                ". Actual: ", state[i].shape().DebugString(), "."));
+        ctx->set_output(i, state[i]);
+      }
+    });
+  }
+
+ private:
+  // Function named by the `f` attr; applied to (state, element) each step.
+  NameAttrList reduce_func_;
+  // Expected dtype of each output component (`output_types` attr).
+  DataTypeVector output_types_;
+  // Expected, possibly partial, shape of each output component
+  // (`output_shapes` attr).
+  std::vector<PartialTensorShape> output_shapes_;
+  // Value of the `use_inter_op_parallelism` attr, forwarded to
+  // `CapturedFunction::Create`.
+  bool use_inter_op_parallelism_;
+  // Worker on which the potentially blocking reduction is scheduled.
+  BackgroundWorker background_worker_;
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -1146,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
+REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
+ ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index dbe31f37b8..2a911aa368 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -32,8 +32,7 @@ namespace {
class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ScanDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -258,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector state_types_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;