diff options
Diffstat (limited to 'tensorflow/contrib/data/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/contrib/data/ops/dataset_ops.cc | 232 |
1 files changed, 232 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc new file mode 100644 index 0000000000..1574384cb2 --- /dev/null +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -0,0 +1,232 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- + +// The ops in this section can be composed to define an input +// pipeline. Each op produces a DT_VARIANT tensor that represents +// a DAG of "dataset" objects. An "dataset" object can be converted +// to a stateful "iterator" by passing the "dataset" to the +// "MakeIterator" op. +// +// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are +// not presently serializable. To avoid issues with constant folding, ensure +// that any "source dataset" ops (i.e. ops that output a dataset and do not +// take one as input) are marked "stateful". + +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +)doc"); + +REGISTER_OP("MapAndBatchDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_batches: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + +batch_size: A scalar representing the number of elements to accumulate in a + batch. It determines the number of concurrent invocations of `f` that process + elements from `input_dataset` in parallel. +num_parallel_batches: A scalar representing the number of batches to create in + parallel. Processing multiple batches in parallel benefits workloads prone to + stragglers. +)doc"); + +REGISTER_OP("ScanDataset") + .Input("input_dataset: variant") + .Input("initial_state: Tstate") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("f: func") + .Attr("Tstate: list(type) >= 1") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset successively reduces `f` over the elements of `input_dataset`. +)doc"); + +REGISTER_OP("ParallelInterleaveDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("sloppy: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset variant that contains elements matching + `output_types` and `output_shapes`. +)doc"); + +REGISTER_OP("GroupByWindowDataset") + .Input("input_dataset: variant") + .Input("key_func_other_arguments: Tkey_func_other_arguments") + .Input("reduce_func_other_arguments: Treduce_func_other_arguments") + .Input( + "window_size_func_other_arguments: Twindow_size_func_other_arguments") + .Output("handle: variant") + .Attr("key_func: func") + .Attr("reduce_func: func") + .Attr("window_size_func: func") + .Attr("Tkey_func_other_arguments: list(type) >= 0") + .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Twindow_size_func_other_arguments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that computes a windowed group-by on `input_dataset`. + +// TODO(mrry): Support non-int64 keys. + +key_func: A function mapping an element of `input_dataset`, concatenated + with `key_func_other_arguments` to a scalar value of type DT_INT64. +)doc"); + +REGISTER_OP("DenseToSparseBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("row_shape: int64") + .Output("handle: variant") + // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. + .Attr("output_types: list(type) >= 1") + // NOTE(mrry): the 1st and 2nd elements will be vectors. + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that yields a SparseTensor for each element of the input. + +input_dataset: A handle to an input dataset. Must have a single component. +batch_size: A scalar representing the number of elements to accumulate in a + batch. +row_shape: A vector representing the dense shape of each row in the produced + SparseTensor. The shape may be partially specified, using `-1` to indicate + that a particular dimension should use the maximum size of all batch elements. +)doc"); + +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. +)doc"); + +REGISTER_OP("DatasetToSingleElement") + .Input("dataset: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }) + .Doc(R"doc( +Outputs the single element from the given dataset. + +dataset: A handle to a dataset that contains a single element. +components: The components of the single element of `input`. +)doc"); + +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +} // namespace tensorflow |