Diffstat (limited to 'tensorflow/contrib/data/ops/dataset_ops.cc')
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc  232
1 file changed, 232 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
new file mode 100644
index 0000000000..1574384cb2
--- /dev/null
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -0,0 +1,232 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+// The ops in this section can be composed to define an input
+// pipeline. Each op produces a DT_VARIANT tensor that represents
+// a DAG of "dataset" objects. A "dataset" object can be converted
+// to a stateful "iterator" by passing the "dataset" to the
+// "MakeIterator" op.
+//
+// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
+// not presently serializable. To avoid issues with constant folding, ensure
+// that any "source dataset" ops (i.e. ops that output a dataset and do not
+// take one as input) are marked "stateful".
+
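+// A minimal end-to-end sketch from the Python side (`tf.data.Dataset` here
+// stands in for the contrib-era Dataset class that wraps these ops):
+//
+//   import tensorflow as tf
+//
+//   dataset = tf.data.Dataset.range(10).map(lambda x: x * 2)
+//   iterator = dataset.make_initializable_iterator()  # uses "MakeIterator"
+//   next_element = iterator.get_next()
+//   with tf.Session() as sess:
+//     sess.run(iterator.initializer)
+//     print(sess.run(next_element))  # ==> 0
+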
+REGISTER_OP("IgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset`, ignoring errors.
+)doc");
+
+REGISTER_OP("MapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_batches: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+  batch. Together with `num_parallel_batches`, it determines how many
+  invocations of `f` process elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+ parallel. Processing multiple batches in parallel benefits workloads prone to
+ stragglers.
+)doc");
+
+REGISTER_OP("ScanDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that successively reduces the elements of `input_dataset` using `f`.
+)doc");
+
+REGISTER_OP("ParallelInterleaveDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("sloppy: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+ `other_arguments`, to a Dataset variant that contains elements matching
+ `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input(
+ "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("reduce_func: func")
+ .Attr("window_size_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+  with `key_func_other_arguments`, to a scalar value of type DT_INT64.
+)doc");
+
+REGISTER_OP("DenseToSparseBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("row_shape: int64")
+ .Output("handle: variant")
+ // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+ .Attr("output_types: list(type) >= 1")
+ // NOTE(mrry): the 1st and 2nd elements will be vectors.
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that yields a SparseTensor for each element of the input.
+
+input_dataset: A handle to an input dataset. Must have a single component.
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch.
+row_shape: A vector representing the dense shape of each row in the produced
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
+)doc");
+
+REGISTER_OP("SqlDataset")
+ .Input("driver_name: string")
+ .Input("data_source_name: string")
+ .Input("query: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that executes a SQL query and emits rows of the result set.
+
+driver_name: The database type. Currently, the only supported type is 'sqlite'.
+data_source_name: A connection string to connect to the database.
+query: A SQL query to execute.
+)doc");
+
+REGISTER_OP("DatasetToSingleElement")
+ .Input("dataset: variant")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+      if (output_shapes.size() != static_cast<size_t>(c->num_outputs())) {
+        return errors::InvalidArgument(
+            "`output_shapes` must be the same length as `output_types` (",
+            output_shapes.size(), " vs. ", c->num_outputs(), ")");
+      }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the single element from the given dataset.
+
+dataset: A handle to a dataset that contains a single element.
+components: The components of the single element of `dataset`.
+)doc");
+
+REGISTER_OP("SerializeIterator")
+ .Input("resource_handle: resource")
+ .Output("serialized: variant")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+ .Input("resource_handle: resource")
+ .Input("serialized: variant")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+} // namespace tensorflow