Diffstat (limited to 'tensorflow/contrib/data/ops/dataset_ops.cc')
 tensorflow/contrib/data/ops/dataset_ops.cc | 99
 1 file changed, 94 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index 8413fcaf87..66a7c7fdcd 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -36,6 +36,7 @@ data_input_datasets: `N` datasets with the same type that will be interleaved
REGISTER_OP("CSVDataset")
.Input("filenames: string")
+ .Input("compression_type: string")
.Input("buffer_size: int64")
.Input("header: bool")
.Input("field_delim: string")
@@ -52,17 +53,18 @@ REGISTER_OP("CSVDataset")
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
- // `buffer_size`, `header`, `field_delim`, `use_quote_delim`,
- // `na_value` must be scalars
+ // `compression_type`, `buffer_size`, `header`, `field_delim`,
+ // `use_quote_delim`, `na_value` must be scalars
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
// `select_cols` must be a vector
- TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused));
- // `record_defaults` must be a list of scalars...?
- for (size_t i = 7; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
+ // `record_defaults` must be lists of scalars
+ for (size_t i = 8; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
}
return shape_inference::ScalarShape(c);
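Context for the hunk above: adding `compression_type` as the second input of CSVDataset shifts every later input index up by one, so the scalar checks now cover inputs 1 through 6 (`compression_type`, `buffer_size`, `header`, `field_delim`, `use_quote_delim`, `na_value`), `select_cols` moves to input 7, and the variadic `record_defaults` tail starts at input 8. The following standalone sketch is not part of the patch; it uses a hypothetical op name and a reduced input list purely to illustrate the same WithRankAtMost/WithRank pattern.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// Hypothetical op, shown only to illustrate the rank checks used by the
// CSVDataset shape function: a scalar-or-vector input, two scalar inputs,
// and a variadic tail whose tensors must all be vectors.
REGISTER_OP("ExampleRankCheckedDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Input("record_defaults: Tdefaults")
    .Output("handle: variant")
    .Attr("Tdefaults: list({float,int32,int64,string}) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` may be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` and `buffer_size` must be scalars.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // Each `record_defaults` tensor must be a vector.
      for (int i = 3; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
      }
      // The dataset itself is represented by a scalar variant handle.
      return shape_inference::ScalarShape(c);
    });

}  // namespace tensorflow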
@@ -143,6 +145,80 @@ Resets the FunctionBufferingResource.
function_buffer_resource: The FunctionBufferingResource handle.
)doc");
+REGISTER_OP("MultiDeviceIterator")
+ .Output("handle: resource")
+ .Attr("devices: list(string) >= 1")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Doc(R"doc(
+Creates a MultiDeviceIterator resource.
+
+handle: Handle to the resource created.
+devices: A list of devices the iterator works across.
+shared_name: If non-empty, this resource will be shared under the given name
+ across multiple sessions.
+container: If non-empty, this resource is placed in the given container.
+ Otherwise, a default container is used.
+output_types: The type list for the return values.
+output_shapes: The list of shapes being produced.
+)doc");
+
+REGISTER_OP("MultiDeviceIteratorInit")
+ .Input("dataset: variant")
+ .Input("multi_device_iterator: resource")
+ .Output("incarnation_id: int64")
+ .Doc(R"doc(
+Initializes the multi device iterator with the given dataset.
+incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator
+ is running.
+dataset: Dataset to be iterated upon.
+multi_device_iterator: A MultiDeviceIteratorResource.
+)doc");
+
+REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
+ .Input("multi_device_iterator: resource")
+ .Input("shard_num: int32")
+ .Input("incarnation_id: int64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Doc(R"doc(
+Gets next element for the provided shard number.
+
+multi_device_iterator: A MultiDeviceIterator resource.
+shard_num: Integer representing which shard to fetch data for.
+incarnation_id: Which incarnation of the MultiDeviceIterator is running.
+components: Result of the get_next on the dataset.
+output_types: The type list for the return values.
+output_shapes: The list of shapes being produced.
+)doc");
+
+REGISTER_OP("MultiDeviceIteratorToStringHandle")
+ .Input("multi_device_iterator: resource")
+ .Output("string_handle: string")
+ .Doc(R"doc(
+Produces a string handle for the given MultiDeviceIterator.
+
+multi_device_iterator: A MultiDeviceIterator resource.
+string_handle: A string representing the resource.
+)doc");
+
+REGISTER_OP("MultiDeviceIteratorFromStringHandle")
+ .Input("string_handle: string")
+ .Output("multi_device_iterator: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .Doc(R"doc(
+Generates a MultiDeviceIterator resource from its provided string handle.
+
+string_handle: String representing the resource.
+multi_device_iterator: A MultiDeviceIterator resource.
+output_types: The type list for the return values.
+output_shapes: The list of shapes being produced.
+)doc");
+
REGISTER_OP("ThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
@@ -175,4 +251,17 @@ display_name: A human-readable name for the threads that may be visible in
some visualizations.
)doc");
+REGISTER_OP("AssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
} // namespace tensorflow
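A quick way to sanity-check registrations like the ones above is to look the op names up in the global op registry. The snippet below is a standalone sketch, not part of this patch; it assumes the binary is linked against the op library built from this file (with its static initializers retained) and uses only the op names registered above.

#include <iostream>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"

int main() {
  // Each name below is registered by dataset_ops.cc in this patch; lookup
  // succeeds only if that library is linked into this binary.
  for (const char* name :
       {"MultiDeviceIterator", "MultiDeviceIteratorInit",
        "MultiDeviceIteratorGetNextFromShard",
        "MultiDeviceIteratorToStringHandle",
        "MultiDeviceIteratorFromStringHandle", "AssertNextDataset"}) {
    const tensorflow::OpDef* op_def = nullptr;
    tensorflow::Status s =
        tensorflow::OpRegistry::Global()->LookUpOpDef(name, &op_def);
    std::cout << name << ": " << (s.ok() ? "registered" : s.ToString())
              << std::endl;
  }
  return 0;
}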