aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-28 08:38:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 08:46:34 -0700
commitc7bb3c3d65e4e064d53630d4b524522eed6f3f44 (patch)
tree1fd0b73ab916093c80dcd289154035bba5fb393d /tensorflow/core/ops
parente06783e7bb80f664c7ec9be90680ac6ddcbd598f (diff)
[tf.data] Move `tf.contrib.data` C++ code to a core "experimental" directory.
NOTE: All ops and kernels previously previously defined in tensorflow/contrib/data have had their name prefixed with "Experimental" to indicate that they are not (yet) stable, and thus not subject to backwards or forwards compatibility guarantees. PiperOrigin-RevId: 214940819
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc207
1 files changed, 207 insertions, 0 deletions
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
new file mode 100644
index 0000000000..f6bd5dce26
--- /dev/null
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
+ .Input("selector_input_dataset: variant")
+ .Input("data_input_datasets: N * variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalCSVDataset")
+ .Input("filenames: string")
+ .Input("compression_type: string")
+ .Input("buffer_size: int64")
+ .Input("header: bool")
+ .Input("field_delim: string")
+ .Input("use_quote_delim: bool")
+ .Input("na_value: string")
+ .Input("select_cols: int64")
+ .Input("record_defaults: output_types")
+ .Output("handle: variant")
+ .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type`, `buffer_size`, `header`, `field_delim`,
+ // `use_quote_delim`, `na_value` must be scalars
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+ // `select_cols` must be a vector
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
+ // `record_defaults` must be lists of scalars
+ for (size_t i = 8; i < c->num_inputs(); ++i) {
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalUniqueDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIteratorGetDevice")
+ .Input("resource: resource")
+ .Output("device: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResource")
+ .Input("string_arg: string")
+ .Input("target_device: string")
+ .Output("resource: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("f: func")
+ .Attr("buffer_size: int")
+ .Attr("output_types: list(type)")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
+ .Input("function_buffer_resource: resource")
+ .Attr("output_types: list(type)")
+ .Output("output: output_types")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
+ .Input("function_buffer_resource: resource")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalThreadPoolDataset")
+ .Input("input_dataset: variant")
+ .Input("thread_pool: resource")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalThreadPoolHandle")
+ .Output("handle: resource")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
+ .Attr("display_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalLMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the materialize handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
+} // namespace tensorflow