aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Brennan Saeta <saeta@google.com> 2017-06-09 14:24:58 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-09 14:31:24 -0700
commit 7ce6e4f871b0767547a9a5cfb9d19dba79704489 (patch)
tree 576dd376ac9d6ea556b8a3321e103fd5128acba1
parent 961c5a6c0da901748af17b816545195653e7a228 (diff)
Remove some boilerplate when implementing Datasets
This change adds 2 classes:
- DatasetOpKernel: An abstract class that simplifies the creation of a Dataset object by automatically handling the creation of an associated resource handle. It also includes a helper to parse Scalar arguments.
- UnaryDatasetOpKernel: An abstract class that simplifies parsing the input dataset for datasets that are one-to-one transformations of existing datasets (e.g. take, drop, shuffle, map, filter, etc.)

Additionally, this change goes through the existing datasets and converts them to use the new helper classes.

A subsequent change will add support for parsing Lists of tensor arguments.

PiperOrigin-RevId: 158562508
-rw-r--r-- tensorflow/core/kernels/BUILD 1
-rw-r--r-- tensorflow/core/kernels/batch_dataset_op.cc 32
-rw-r--r-- tensorflow/core/kernels/dataset.cc 42
-rw-r--r-- tensorflow/core/kernels/dataset.h 37
-rw-r--r-- tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc 32
-rw-r--r-- tensorflow/core/kernels/filter_dataset_op.cc 21
-rw-r--r-- tensorflow/core/kernels/flat_map_dataset_op.cc 23
-rw-r--r-- tensorflow/core/kernels/group_by_window_dataset_op.cc 33
-rw-r--r-- tensorflow/core/kernels/map_dataset_op.cc 23
-rw-r--r-- tensorflow/core/kernels/padded_batch_dataset_op.cc 34
-rw-r--r-- tensorflow/core/kernels/range_dataset_op.cc 41
-rw-r--r-- tensorflow/core/kernels/repeat_dataset_op.cc 28
-rw-r--r-- tensorflow/core/kernels/shuffle_dataset_op.cc 51
-rw-r--r-- tensorflow/core/kernels/skip_dataset_op.cc 32
-rw-r--r-- tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc 14
-rw-r--r-- tensorflow/core/kernels/take_dataset_op.cc 35
-rw-r--r-- tensorflow/core/kernels/tensor_dataset_op.cc 14
-rw-r--r-- tensorflow/core/kernels/tensor_slice_dataset_op.cc 15
-rw-r--r-- tensorflow/core/kernels/zip_dataset_op.cc 18
19 files changed, 222 insertions, 304 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 214897b7fa..491302116b 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5163,6 +5163,7 @@ tf_mkl_kernel_library(
cc_library(
name = "dataset",
+ srcs = ["dataset.cc"],
hdrs = ["dataset.h"],
deps = [
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc
index c8289eff2a..67eff44a5d 100644
--- a/tensorflow/core/kernels/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/batch_dataset_op.cc
@@ -24,33 +24,21 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class BatchDatasetOp : public OpKernel {
+class BatchDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit BatchDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- // Create a new BatchDatasetOp::Dataset, insert it in the step-local
- // container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* batch_size_t;
- OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()),
- errors::InvalidArgument("batch_size must be a scalar"));
- const int64 batch_size = batch_size_t->flat<int64>()(0);
+ explicit BatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 batch_size;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
OP_REQUIRES(
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
- DatasetBase* dataset = new Dataset(batch_size, input);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(batch_size, input);
}
private:
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
new file mode 100644
index 0000000000..925cbda56e
--- /dev/null
+++ b/tensorflow/core/kernels/dataset.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/dataset.h"
+
+namespace tensorflow {
+
+void DatasetOpKernel::Compute(OpKernelContext* ctx) {  // Builds the dataset via MakeDataset() and returns a resource handle to it as output 0.
+ DatasetBase* dataset = nullptr;
+ MakeDataset(ctx, &dataset);  // Subclass hook; it has no return value, so failure is detected through ctx->status() below.
+ if (ctx->status().ok()) {  // Only publish a handle when MakeDataset left the context OK.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));  // Output 0 is a scalar tensor that will hold the handle.
+ ResourceHandle handle = MakeResourceHandle<DatasetBase>(
+ ctx, ctx->step_container()->name(), name());  // Handle is keyed by the step container's name and this kernel's name.
+ OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));  // NOTE(review): `dataset` is not released in this function if allocate_output or CreateResource fails -- confirm who owns it on those paths.
+ output->flat<ResourceHandle>()(0) = handle;
+ }
+}
+
+void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,  // Resolves the input dataset, then delegates to the three-argument overload.
+ DatasetBase** output) {
+ DatasetBase* input;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));  // Input 0 carries the handle of the dataset being transformed; returns early on lookup failure.
+ core::ScopedUnref unref_input(input);  // Presumably drops the reference obtained by LookupResource when this scope exits -- TODO confirm.
+
+ MakeDataset(ctx, input, output);  // Overload implemented by the concrete op; it writes the new dataset to *output.
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index 83ffabe224..da56844dbe 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -162,6 +162,43 @@ class DatasetIterator : public IteratorBase {
// shared dataset resource.
};
+// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
+// graph execution engine: Compute() is final and delegates dataset construction to MakeDataset().
+class DatasetOpKernel : public OpKernel {
+ public:
+ DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution. Implementations store the newly created dataset in `*output`; it has no return value, so errors are reported on `ctx`.
+ virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
+
+ template <typename T>  // Helper: reads the input named `argument_name` into `*output` as a scalar of type T.
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));  // Propagates the lookup error if the named input cannot be fetched.
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Encapsulates the work required to plug unary Datasets into the core
+// TensorFlow graph execution engine. Subclasses receive the already-resolved input dataset.
+class UnaryDatasetOpKernel : public DatasetOpKernel {
+ public:
+ UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;  // Looks up the dataset referenced by input 0 and forwards to the overload below.
+ virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) = 0;  // Subclass hook: build `*output` from `input`.
+};
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_
diff --git a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
index 2c36093355..50b8a05492 100644
--- a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
@@ -24,28 +24,23 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class DenseToSparseBatchDatasetOp : public OpKernel {
+class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
+ : UnaryDatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
// Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the
// step-local container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
OP_REQUIRES(
ctx, input->output_dtypes().size() == 1,
errors::InvalidArgument("DenseToSparseBatchDataset only supports "
"inputs with a single component."));
- const Tensor* batch_size_t;
- OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()),
- errors::InvalidArgument("batch_size must be a scalar"));
- const int64 batch_size = batch_size_t->flat<int64>()(0);
+ int64 batch_size;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
OP_REQUIRES(
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
@@ -59,11 +54,11 @@ class DenseToSparseBatchDatasetOp : public OpKernel {
row_shape.AddDim(row_shape_t->vec<int64>()(i));
}
- DatasetBase* dataset = nullptr;
+ *output = nullptr;
#define HANDLE_TYPE(DT) \
if (input->output_dtypes()[0] == DT) { \
- dataset = \
+ *output = \
new Dataset<EnumToDataType<DT>::Type>(batch_size, row_shape, input); \
}
HANDLE_TYPE(DT_FLOAT);
@@ -85,16 +80,9 @@ class DenseToSparseBatchDatasetOp : public OpKernel {
HANDLE_TYPE(DT_QUINT16);
#undef HANDLE_TYPE
OP_REQUIRES(
- ctx, dataset != nullptr,
+ ctx, *output != nullptr,
errors::Unimplemented("DenseToSparseBatchDataset unhandled data type: ",
input->output_dtypes()[0]));
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
}
private:
diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc
index 62ad921062..3503c45f9a 100644
--- a/tensorflow/core/kernels/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/filter_dataset_op.cc
@@ -28,18 +28,16 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class FilterDatasetOp : public OpKernel {
+class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
- void Compute(OpKernelContext* ctx) override {
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
std::vector<Tensor> other_arguments;
@@ -53,14 +51,7 @@ class FilterDatasetOp : public OpKernel {
std::move(other_arguments),
&captured_func));
- DatasetBase* dataset = new Dataset(input, std::move(captured_func));
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(input, std::move(captured_func));
}
private:
diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc
index 68a6cf1960..eb55c01b12 100644
--- a/tensorflow/core/kernels/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/flat_map_dataset_op.cc
@@ -28,20 +28,18 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class FlatMapDatasetOp : public OpKernel {
+class FlatMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit FlatMapDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
- void Compute(OpKernelContext* ctx) override {
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
std::vector<Tensor> other_arguments;
@@ -55,15 +53,8 @@ class FlatMapDatasetOp : public OpKernel {
std::move(other_arguments),
&captured_func));
- DatasetBase* dataset = new Dataset(input, std::move(captured_func),
- output_types_, output_shapes_);
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(input, std::move(captured_func), output_types_,
+ output_shapes_);
}
private:
diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc
index a58c15a097..948e83390e 100644
--- a/tensorflow/core/kernels/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc
@@ -29,26 +29,22 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class GroupByWindowDatasetOp : public OpKernel {
+class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
- void Compute(OpKernelContext* ctx) override {
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* window_size_t;
- OP_REQUIRES_OK(ctx, ctx->input("window_size", &window_size_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(window_size_t->shape()),
- errors::InvalidArgument("window_size must be a scalar"));
- const int64 window_size = window_size_t->flat<int64>()(0);
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 window_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
@@ -84,16 +80,9 @@ class GroupByWindowDatasetOp : public OpKernel {
std::move(reduce_func_other_arguments),
&captured_reduce_func));
- DatasetBase* dataset = new Dataset(
- input, window_size, std::move(captured_key_func),
- std::move(captured_reduce_func), output_types_, output_shapes_);
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(input, window_size, std::move(captured_key_func),
+ std::move(captured_reduce_func), output_types_,
+ output_shapes_);
}
private:
diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc
index 08308d8557..f755097324 100644
--- a/tensorflow/core/kernels/map_dataset_op.cc
+++ b/tensorflow/core/kernels/map_dataset_op.cc
@@ -28,20 +28,18 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class MapDatasetOp : public OpKernel {
+class MapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
- void Compute(OpKernelContext* ctx) override {
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
std::vector<Tensor> other_arguments;
@@ -55,15 +53,8 @@ class MapDatasetOp : public OpKernel {
std::move(other_arguments),
&captured_func));
- DatasetBase* dataset = new Dataset(input, std::move(captured_func),
- output_types_, output_shapes_);
-
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(input, std::move(captured_func), output_types_,
+ output_shapes_);
}
private:
diff --git a/tensorflow/core/kernels/padded_batch_dataset_op.cc b/tensorflow/core/kernels/padded_batch_dataset_op.cc
index b0c000dd25..bb1eda101f 100644
--- a/tensorflow/core/kernels/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/padded_batch_dataset_op.cc
@@ -122,22 +122,16 @@ Status SetElementZero(Tensor* element, const Tensor& padding) {
element->dtype());
}
-class PaddedBatchDatasetOp : public OpKernel {
+class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- // Create a new BatchDatasetOp::Dataset, insert it in the step-local
- // container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* batch_size_t;
- OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()),
- errors::InvalidArgument("batch_size must be a scalar"));
- const int64 batch_size = batch_size_t->flat<int64>()(0);
+ explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 batch_size;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "batch_size", &batch_size));
OP_REQUIRES(
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
@@ -188,14 +182,8 @@ class PaddedBatchDatasetOp : public OpKernel {
padding_values.push_back(tensor::DeepCopy(padding_value_t));
}
- DatasetBase* dataset = new Dataset(batch_size, std::move(padded_shapes),
- std::move(padding_values), input);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(batch_size, std::move(padded_shapes),
+ std::move(padding_values), input);
}
private:
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index c181f6e804..8220cb9379 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -24,38 +24,23 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class RangeDatasetOp : public OpKernel {
+class RangeDatasetOp : public DatasetOpKernel {
public:
- explicit RangeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor* start_t;
- OP_REQUIRES_OK(ctx, ctx->input("start", &start_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_t->shape()),
- errors::InvalidArgument("start must be a scalar"));
- const int64 start = start_t->flat<int64>()(0);
-
- const Tensor* stop_t;
- OP_REQUIRES_OK(ctx, ctx->input("stop", &stop_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_t->shape()),
- errors::InvalidArgument("stop must be a scalar"));
- const int64 stop = stop_t->flat<int64>()(0);
-
- const Tensor* step_t;
- OP_REQUIRES_OK(ctx, ctx->input("step", &step_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(step_t->shape()),
- errors::InvalidArgument("step must be a scalar"));
- const int64 step = step_t->flat<int64>()(0);
+ explicit RangeDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ int64 start;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "start", &start));
+
+ int64 stop;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stop", &stop));
+
+ int64 step;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "step", &step));
OP_REQUIRES(ctx, step != 0,
errors::InvalidArgument("step must be a non-zero integer."));
- DatasetBase* dataset = new Dataset(start, stop, step);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(start, stop, step);
}
private:
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 8fc59e1779..02103400b1 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -24,28 +24,20 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class RepeatDatasetOp : public OpKernel {
+class RepeatDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit RepeatDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ explicit RepeatDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
// Create a new RepeatDatasetOp::Dataset, insert it in the step-local
// container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* count_t;
- OP_REQUIRES_OK(ctx, ctx->input("count", &count_t));
- const int64 count = count_t->flat<int64>()(0);
-
- DatasetBase* dataset = new Dataset(count, input);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ int64 count;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
+
+ *output = new Dataset(count, input);
}
private:
diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc
index 14e7d1bf97..86786287bb 100644
--- a/tensorflow/core/kernels/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/shuffle_dataset_op.cc
@@ -27,45 +27,28 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class ShuffleDatasetOp : public OpKernel {
+class ShuffleDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit ShuffleDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- // Create a new ShuffleDatasetOp::Dataset, insert it in the step-local
- // container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* buffer_size_t;
- OP_REQUIRES_OK(ctx, ctx->input("buffer_size", &buffer_size_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(buffer_size_t->shape()),
- errors::InvalidArgument("buffer_size must be a scalar"));
- const int64 buffer_size = buffer_size_t->flat<int64>()(0);
+ explicit ShuffleDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ // Create a new ShuffleDatasetOp::Dataset, and return it as the output.
+ int64 buffer_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(
ctx, buffer_size > 0,
errors::InvalidArgument("buffer_size must be greater than zero."));
- const Tensor* seed_t;
- OP_REQUIRES_OK(ctx, ctx->input("seed", &seed_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(seed_t->shape()),
- errors::InvalidArgument("seed must be a scalar"));
- const int64 seed = seed_t->flat<int64>()(0);
-
- const Tensor* seed2_t;
- OP_REQUIRES_OK(ctx, ctx->input("seed2", &seed2_t));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(seed2_t->shape()),
- errors::InvalidArgument("seed2 must be a scalar"));
- const int64 seed2 = seed2_t->flat<int64>()(0);
-
- DatasetBase* dataset = new Dataset(input, buffer_size, seed, seed2);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ int64 seed;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
+
+ int64 seed2;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
+
+ *output = new Dataset(input, buffer_size, seed, seed2);
}
private:
diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc
index 1cff90a05e..ea749c7365 100644
--- a/tensorflow/core/kernels/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/skip_dataset_op.cc
@@ -24,28 +24,18 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class SkipDatasetOp : public OpKernel {
+class SkipDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit SkipDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
- // container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* count_t;
- OP_REQUIRES_OK(ctx, ctx->input("count", &count_t));
- const int64 count = count_t->flat<int64>()(0);
-
- DatasetBase* dataset = new Dataset(count, input);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ explicit SkipDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ // Create a new SkipDatasetOp::Dataset, and return it as the output.
+ int64 count;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
+
+ *output = new Dataset(count, input);
}
private:
diff --git a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc
index 70cab66d64..679b39bef9 100644
--- a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc
@@ -148,12 +148,12 @@ class Dataset : public DatasetBase {
};
template <typename T>
-class SparseTensorSliceDatasetOp : public OpKernel {
+class SparseTensorSliceDatasetOp : public DatasetOpKernel {
public:
explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
+ : DatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
// Create a new SparseTensorSliceDatasetOp::Dataset, insert it in
// the step container, and return it as the output.
const Tensor* indices;
@@ -196,13 +196,7 @@ class SparseTensorSliceDatasetOp : public OpKernel {
sparse::SparseTensor sparse_tensor(
*indices, *values, TensorShape(dense_shape->vec<int64>()), std_order);
- DatasetBase* dataset = new Dataset<T>(sparse_tensor);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset<T>(sparse_tensor);
}
private:
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc
index e27a36bc9b..068ec4b9e3 100644
--- a/tensorflow/core/kernels/take_dataset_op.cc
+++ b/tensorflow/core/kernels/take_dataset_op.cc
@@ -24,28 +24,18 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class TakeDatasetOp : public OpKernel {
+class TakeDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit TakeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
- // container, and return it as the output.
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
- core::ScopedUnref unref_input(input);
-
- const Tensor* count_t;
- OP_REQUIRES_OK(ctx, ctx->input("count", &count_t));
- const int64 count = count_t->flat<int64>()(0);
-
- DatasetBase* dataset = new Dataset(count, input);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ explicit TakeDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ // Create a new TakeDatasetOp::Dataset, and return it as the output.
+ int64 count;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
+ *output = new Dataset(count, input);
}
private:
@@ -124,8 +114,7 @@ class TakeDatasetOp : public OpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU),
- TakeDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc
index 6b6fcb1978..5674af787b 100644
--- a/tensorflow/core/kernels/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_dataset_op.cc
@@ -24,11 +24,11 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class TensorDatasetOp : public OpKernel {
+class TensorDatasetOp : public DatasetOpKernel {
public:
- explicit TensorDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
// Create a new TensorDatasetOp::Dataset and return it as the output; the
// DatasetOpKernel base class handles creating the associated resource handle.
OpInputList inputs;
@@ -40,13 +40,7 @@ class TensorDatasetOp : public OpKernel {
for (const Tensor& t : inputs) {
components.push_back(t);
}
- DatasetBase* dataset = new Dataset(std::move(components));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(std::move(components));
}
private:
diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
index fc70d2ecc5..69dd1584b6 100644
--- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
@@ -24,11 +24,12 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class TensorSliceDatasetOp : public OpKernel {
+class TensorSliceDatasetOp : public DatasetOpKernel {
public:
- explicit TensorSliceDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ explicit TensorSliceDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
// Create a new TensorSliceDatasetOp::Dataset and return it as the output; the
// DatasetOpKernel base class handles creating the associated resource handle.
OpInputList inputs;
@@ -49,13 +50,7 @@ class TensorSliceDatasetOp : public OpKernel {
errors::InvalidArgument(
"All components must have the same size in the 0th dimension"));
}
- DatasetBase* dataset = new Dataset(std::move(components));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
- output->flat<ResourceHandle>()(0) = handle;
+ *output = new Dataset(std::move(components));
}
private:
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index e7fc9bc6b1..325e8a5df9 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -24,11 +24,11 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class ZipDatasetOp : public OpKernel {
+class ZipDatasetOp : public DatasetOpKernel {
public:
- explicit ZipDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ explicit ZipDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
std::vector<DatasetBase*> inputs;
Status s;
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
@@ -43,17 +43,7 @@ class ZipDatasetOp : public OpKernel {
}
if (s.ok()) {
- DatasetBase* dataset = new Dataset(inputs);
- Tensor* output = nullptr;
- s = ctx->allocate_output(0, TensorShape({}), &output);
- if (s.ok()) {
- ResourceHandle handle = MakeResourceHandle<DatasetBase>(
- ctx, ctx->step_container()->name(), name());
- s = CreateResource(ctx, handle, dataset);
- if (s.ok()) {
- output->flat<ResourceHandle>()(0) = handle;
- }
- }
+ *output = new Dataset(inputs);
}
// TODO(mrry): Implement a container that acts as a