diff options
author | Brennan Saeta <saeta@google.com> | 2017-06-09 14:24:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-09 14:31:24 -0700 |
commit | 7ce6e4f871b0767547a9a5cfb9d19dba79704489 (patch) | |
tree | 576dd376ac9d6ea556b8a3321e103fd5128acba1 | |
parent | 961c5a6c0da901748af17b816545195653e7a228 (diff) |
Remove some boilerplate when implementing Datasets
This change adds 2 classes:
- DatasetOpKernel: An abstract class that simplifies the creation of a
  Dataset object by automatically handling the creation of an associated
  resource handle. It also includes a helper to parse scalar arguments.
- UnaryDatasetOpKernel: An abstract class that simplifies parsing the input
  dataset for datasets that are one-to-one transformations of existing
  datasets (e.g. take, skip, shuffle, map, filter).
Additionally, this change goes through the existing datasets and converts them
to use the new helper classes.
A subsequent change will add support for parsing Lists of tensor arguments.
PiperOrigin-RevId: 158562508
19 files changed, 222 insertions, 304 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 214897b7fa..491302116b 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5163,6 +5163,7 @@ tf_mkl_kernel_library( cc_library( name = "dataset", + srcs = ["dataset.cc"], hdrs = ["dataset.h"], deps = [ "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc index c8289eff2a..67eff44a5d 100644 --- a/tensorflow/core/kernels/batch_dataset_op.cc +++ b/tensorflow/core/kernels/batch_dataset_op.cc @@ -24,33 +24,21 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class BatchDatasetOp : public OpKernel { +class BatchDatasetOp : public UnaryDatasetOpKernel { public: - explicit BatchDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Create a new BatchDatasetOp::Dataset, insert it in the step-local - // container, and return it as the output. 
- DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* batch_size_t; - OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()), - errors::InvalidArgument("batch_size must be a scalar")); - const int64 batch_size = batch_size_t->flat<int64>()(0); + explicit BatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "batch_size", &batch_size)); OP_REQUIRES( ctx, batch_size > 0, errors::InvalidArgument("Batch size must be greater than zero.")); - DatasetBase* dataset = new Dataset(batch_size, input); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(batch_size, input); } private: diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc new file mode 100644 index 0000000000..925cbda56e --- /dev/null +++ b/tensorflow/core/kernels/dataset.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/dataset.h" + +namespace tensorflow { + +void DatasetOpKernel::Compute(OpKernelContext* ctx) { + DatasetBase* dataset = nullptr; + MakeDataset(ctx, &dataset); + if (ctx->status().ok()) { + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + ResourceHandle handle = MakeResourceHandle<DatasetBase>( + ctx, ctx->step_container()->name(), name()); + OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); + output->flat<ResourceHandle>()(0) = handle; + } +} + +void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + DatasetBase* input; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); + core::ScopedUnref unref_input(input); + + MakeDataset(ctx, input, output); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index 83ffabe224..da56844dbe 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -162,6 +162,43 @@ class DatasetIterator : public IteratorBase { // shared dataset resource. }; +// Encapsulates the work required to plug a DatasetBase into the core TensorFlow +// graph execution engine. +class DatasetOpKernel : public OpKernel { + public: + DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) final; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. 
+ virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; + + template <typename T> + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar<T>()(); + return Status::OK(); + } +}; + +// Encapsulates the work required to plug unary Datasets into the core +// TensorFlow graph execution engine. +class UnaryDatasetOpKernel : public DatasetOpKernel { + public: + UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) = 0; +}; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_ diff --git a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc index 2c36093355..50b8a05492 100644 --- a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc @@ -24,28 +24,23 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class DenseToSparseBatchDatasetOp : public OpKernel { +class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit DenseToSparseBatchDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} + : UnaryDatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { // Create a new DenseToSparseBatchDatasetOp::Dataset, insert it in the // step-local container, and return it as the output. 
- DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - OP_REQUIRES( ctx, input->output_dtypes().size() == 1, errors::InvalidArgument("DenseToSparseBatchDataset only supports " "inputs with a single component.")); - const Tensor* batch_size_t; - OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()), - errors::InvalidArgument("batch_size must be a scalar")); - const int64 batch_size = batch_size_t->flat<int64>()(0); + int64 batch_size; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "batch_size", &batch_size)); OP_REQUIRES( ctx, batch_size > 0, errors::InvalidArgument("Batch size must be greater than zero.")); @@ -59,11 +54,11 @@ class DenseToSparseBatchDatasetOp : public OpKernel { row_shape.AddDim(row_shape_t->vec<int64>()(i)); } - DatasetBase* dataset = nullptr; + *output = nullptr; #define HANDLE_TYPE(DT) \ if (input->output_dtypes()[0] == DT) { \ - dataset = \ + *output = \ new Dataset<EnumToDataType<DT>::Type>(batch_size, row_shape, input); \ } HANDLE_TYPE(DT_FLOAT); @@ -85,16 +80,9 @@ class DenseToSparseBatchDatasetOp : public OpKernel { HANDLE_TYPE(DT_QUINT16); #undef HANDLE_TYPE OP_REQUIRES( - ctx, dataset != nullptr, + ctx, *output != nullptr, errors::Unimplemented("DenseToSparseBatchDataset unhandled data type: ", input->output_dtypes()[0])); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; } private: diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc index 62ad921062..3503c45f9a 100644 --- a/tensorflow/core/kernels/filter_dataset_op.cc +++ 
b/tensorflow/core/kernels/filter_dataset_op.cc @@ -28,18 +28,16 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class FilterDatasetOp : public OpKernel { +class FilterDatasetOp : public UnaryDatasetOpKernel { public: explicit FilterDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } - void Compute(OpKernelContext* ctx) override { - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); std::vector<Tensor> other_arguments; @@ -53,14 +51,7 @@ class FilterDatasetOp : public OpKernel { std::move(other_arguments), &captured_func)); - DatasetBase* dataset = new Dataset(input, std::move(captured_func)); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(input, std::move(captured_func)); } private: diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc index 68a6cf1960..eb55c01b12 100644 --- a/tensorflow/core/kernels/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/flat_map_dataset_op.cc @@ -28,20 +28,18 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
-class FlatMapDatasetOp : public OpKernel { +class FlatMapDatasetOp : public UnaryDatasetOpKernel { public: explicit FlatMapDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } - void Compute(OpKernelContext* ctx) override { - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); std::vector<Tensor> other_arguments; @@ -55,15 +53,8 @@ class FlatMapDatasetOp : public OpKernel { std::move(other_arguments), &captured_func)); - DatasetBase* dataset = new Dataset(input, std::move(captured_func), - output_types_, output_shapes_); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(input, std::move(captured_func), output_types_, + output_shapes_); } private: diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc index a58c15a097..948e83390e 100644 --- a/tensorflow/core/kernels/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc @@ -29,26 +29,22 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
-class GroupByWindowDatasetOp : public OpKernel { +class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { public: explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } - void Compute(OpKernelContext* ctx) override { - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* window_size_t; - OP_REQUIRES_OK(ctx, ctx->input("window_size", &window_size_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(window_size_t->shape()), - errors::InvalidArgument("window_size must be a scalar")); - const int64 window_size = window_size_t->flat<int64>()(0); + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 window_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size)); OP_REQUIRES( ctx, window_size > 0, errors::InvalidArgument("Window size must be greater than zero.")); @@ -84,16 +80,9 @@ class GroupByWindowDatasetOp : public OpKernel { std::move(reduce_func_other_arguments), &captured_reduce_func)); - DatasetBase* dataset = new Dataset( - input, window_size, std::move(captured_key_func), - std::move(captured_reduce_func), output_types_, output_shapes_); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new 
Dataset(input, window_size, std::move(captured_key_func), + std::move(captured_reduce_func), output_types_, + output_shapes_); } private: diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index 08308d8557..f755097324 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -28,20 +28,18 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class MapDatasetOp : public OpKernel { +class MapDatasetOp : public UnaryDatasetOpKernel { public: explicit MapDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } - void Compute(OpKernelContext* ctx) override { - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); std::vector<Tensor> other_arguments; @@ -55,15 +53,8 @@ class MapDatasetOp : public OpKernel { std::move(other_arguments), &captured_func)); - DatasetBase* dataset = new Dataset(input, std::move(captured_func), - output_types_, output_shapes_); - - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(input, std::move(captured_func), output_types_, + 
output_shapes_); } private: diff --git a/tensorflow/core/kernels/padded_batch_dataset_op.cc b/tensorflow/core/kernels/padded_batch_dataset_op.cc index b0c000dd25..bb1eda101f 100644 --- a/tensorflow/core/kernels/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/padded_batch_dataset_op.cc @@ -122,22 +122,16 @@ Status SetElementZero(Tensor* element, const Tensor& padding) { element->dtype()); } -class PaddedBatchDatasetOp : public OpKernel { +class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { public: - explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Create a new BatchDatasetOp::Dataset, insert it in the step-local - // container, and return it as the output. - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* batch_size_t; - OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(batch_size_t->shape()), - errors::InvalidArgument("batch_size must be a scalar")); - const int64 batch_size = batch_size_t->flat<int64>()(0); + explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "batch_size", &batch_size)); OP_REQUIRES( ctx, batch_size > 0, errors::InvalidArgument("Batch size must be greater than zero.")); @@ -188,14 +182,8 @@ class PaddedBatchDatasetOp : public OpKernel { padding_values.push_back(tensor::DeepCopy(padding_value_t)); } - DatasetBase* dataset = new Dataset(batch_size, std::move(padded_shapes), - std::move(padding_values), input); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, 
ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(batch_size, std::move(padded_shapes), + std::move(padding_values), input); } private: diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc index c181f6e804..8220cb9379 100644 --- a/tensorflow/core/kernels/range_dataset_op.cc +++ b/tensorflow/core/kernels/range_dataset_op.cc @@ -24,38 +24,23 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class RangeDatasetOp : public OpKernel { +class RangeDatasetOp : public DatasetOpKernel { public: - explicit RangeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor* start_t; - OP_REQUIRES_OK(ctx, ctx->input("start", &start_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_t->shape()), - errors::InvalidArgument("start must be a scalar")); - const int64 start = start_t->flat<int64>()(0); - - const Tensor* stop_t; - OP_REQUIRES_OK(ctx, ctx->input("stop", &stop_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_t->shape()), - errors::InvalidArgument("stop must be a scalar")); - const int64 stop = stop_t->flat<int64>()(0); - - const Tensor* step_t; - OP_REQUIRES_OK(ctx, ctx->input("step", &step_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(step_t->shape()), - errors::InvalidArgument("step must be a scalar")); - const int64 step = step_t->flat<int64>()(0); + explicit RangeDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + int64 start; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "start", &start)); + + int64 stop; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stop", &stop)); + + int64 step; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "step", &step)); 
OP_REQUIRES(ctx, step != 0, errors::InvalidArgument("step must be a non-zero integer.")); - DatasetBase* dataset = new Dataset(start, stop, step); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(start, stop, step); } private: diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index 8fc59e1779..02103400b1 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -24,28 +24,20 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class RepeatDatasetOp : public OpKernel { +class RepeatDatasetOp : public UnaryDatasetOpKernel { public: - explicit RepeatDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit RepeatDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { // Create a new RepeatDatasetOp::Dataset, insert it in the step-local // container, and return it as the output. 
- DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* count_t; - OP_REQUIRES_OK(ctx, ctx->input("count", &count_t)); - const int64 count = count_t->flat<int64>()(0); - - DatasetBase* dataset = new Dataset(count, input); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + int64 count; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); + + *output = new Dataset(count, input); } private: diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc index 14e7d1bf97..86786287bb 100644 --- a/tensorflow/core/kernels/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/shuffle_dataset_op.cc @@ -27,45 +27,28 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class ShuffleDatasetOp : public OpKernel { +class ShuffleDatasetOp : public UnaryDatasetOpKernel { public: - explicit ShuffleDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Create a new ShuffleDatasetOp::Dataset, insert it in the step-local - // container, and return it as the output. 
- DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* buffer_size_t; - OP_REQUIRES_OK(ctx, ctx->input("buffer_size", &buffer_size_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(buffer_size_t->shape()), - errors::InvalidArgument("buffer_size must be a scalar")); - const int64 buffer_size = buffer_size_t->flat<int64>()(0); + explicit ShuffleDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + // Create a new ShuffleDatasetOp::Dataset, and return it as the output. + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); OP_REQUIRES( ctx, buffer_size > 0, errors::InvalidArgument("buffer_size must be greater than zero.")); - const Tensor* seed_t; - OP_REQUIRES_OK(ctx, ctx->input("seed", &seed_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(seed_t->shape()), - errors::InvalidArgument("seed must be a scalar")); - const int64 seed = seed_t->flat<int64>()(0); - - const Tensor* seed2_t; - OP_REQUIRES_OK(ctx, ctx->input("seed2", &seed2_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(seed2_t->shape()), - errors::InvalidArgument("seed2 must be a scalar")); - const int64 seed2 = seed2_t->flat<int64>()(0); - - DatasetBase* dataset = new Dataset(input, buffer_size, seed, seed2); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + int64 seed; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed)); + + int64 seed2; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2)); + + *output = new Dataset(input, buffer_size, 
seed, seed2); } private: diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc index 1cff90a05e..ea749c7365 100644 --- a/tensorflow/core/kernels/skip_dataset_op.cc +++ b/tensorflow/core/kernels/skip_dataset_op.cc @@ -24,28 +24,18 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class SkipDatasetOp : public OpKernel { +class SkipDatasetOp : public UnaryDatasetOpKernel { public: - explicit SkipDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Create a new RepeatDatasetOp::Dataset, insert it in the step-local - // container, and return it as the output. - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* count_t; - OP_REQUIRES_OK(ctx, ctx->input("count", &count_t)); - const int64 count = count_t->flat<int64>()(0); - - DatasetBase* dataset = new Dataset(count, input); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + explicit SkipDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + // Create a new RepeatDatasetOp::Dataset, and return it as the output. 
+ int64 count; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); + + *output = new Dataset(count, input); } private: diff --git a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc index 70cab66d64..679b39bef9 100644 --- a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc @@ -148,12 +148,12 @@ class Dataset : public DatasetBase { }; template <typename T> -class SparseTensorSliceDatasetOp : public OpKernel { +class SparseTensorSliceDatasetOp : public DatasetOpKernel { public: explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} + : DatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { // Create a new SparseTensorSliceDatasetOp::Dataset, insert it in // the step container, and return it as the output. const Tensor* indices; @@ -196,13 +196,7 @@ class SparseTensorSliceDatasetOp : public OpKernel { sparse::SparseTensor sparse_tensor( *indices, *values, TensorShape(dense_shape->vec<int64>()), std_order); - DatasetBase* dataset = new Dataset<T>(sparse_tensor); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset<T>(sparse_tensor); } private: diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc index e27a36bc9b..068ec4b9e3 100644 --- a/tensorflow/core/kernels/take_dataset_op.cc +++ b/tensorflow/core/kernels/take_dataset_op.cc @@ -24,28 +24,18 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
-class TakeDatasetOp : public OpKernel { +class TakeDatasetOp : public UnaryDatasetOpKernel { public: - explicit TakeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Create a new RepeatDatasetOp::Dataset, insert it in the step-local - // container, and return it as the output. - DatasetBase* input; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input)); - core::ScopedUnref unref_input(input); - - const Tensor* count_t; - OP_REQUIRES_OK(ctx, ctx->input("count", &count_t)); - const int64 count = count_t->flat<int64>()(0); - - DatasetBase* dataset = new Dataset(count, input); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + explicit TakeDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + // Create a new TakeDatasetOp::Dataset, and return it as the output. 
+ int64 count; + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); + *output = new Dataset(count, input); } private: @@ -124,8 +114,7 @@ class TakeDatasetOp : public OpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), - TakeDatasetOp); +REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); } // namespace diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc index 6b6fcb1978..5674af787b 100644 --- a/tensorflow/core/kernels/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_dataset_op.cc @@ -24,11 +24,11 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class TensorDatasetOp : public OpKernel { +class TensorDatasetOp : public DatasetOpKernel { public: - explicit TensorDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { // Create a new TensorDatasetOp::Dataset, insert it in the step // container, and return it as the output. 
OpInputList inputs; @@ -40,13 +40,7 @@ class TensorDatasetOp : public OpKernel { for (const Tensor& t : inputs) { components.push_back(t); } - DatasetBase* dataset = new Dataset(std::move(components)); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(std::move(components)); } private: diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc index fc70d2ecc5..69dd1584b6 100644 --- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc @@ -24,11 +24,12 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class TensorSliceDatasetOp : public OpKernel { +class TensorSliceDatasetOp : public DatasetOpKernel { public: - explicit TensorSliceDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit TensorSliceDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { // Create a new TensorDatasetOp::Dataset, insert it in the step // container, and return it as the output. 
OpInputList inputs; @@ -49,13 +50,7 @@ class TensorSliceDatasetOp : public OpKernel { errors::InvalidArgument( "All components must have the same size in the 0th dimension")); } - DatasetBase* dataset = new Dataset(std::move(components)); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset)); - output->flat<ResourceHandle>()(0) = handle; + *output = new Dataset(std::move(components)); } private: diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc index e7fc9bc6b1..325e8a5df9 100644 --- a/tensorflow/core/kernels/zip_dataset_op.cc +++ b/tensorflow/core/kernels/zip_dataset_op.cc @@ -24,11 +24,11 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class ZipDatasetOp : public OpKernel { +class ZipDatasetOp : public DatasetOpKernel { public: - explicit ZipDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit ZipDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { std::vector<DatasetBase*> inputs; Status s; for (size_t i = 0; i < ctx->num_inputs(); ++i) { @@ -43,17 +43,7 @@ class ZipDatasetOp : public OpKernel { } if (s.ok()) { - DatasetBase* dataset = new Dataset(inputs); - Tensor* output = nullptr; - s = ctx->allocate_output(0, TensorShape({}), &output); - if (s.ok()) { - ResourceHandle handle = MakeResourceHandle<DatasetBase>( - ctx, ctx->step_container()->name(), name()); - s = CreateResource(ctx, handle, dataset); - if (s.ok()) { - output->flat<ResourceHandle>()(0) = handle; - } - } + *output = new Dataset(inputs); } // TODO(mrry): Implement a container that acts as a |