aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/dataset_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r--tensorflow/core/ops/dataset_ops.cc20
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1a5ad8f421..9d2b3af51d 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -210,6 +210,7 @@ REGISTER_OP("ParallelMapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@@ -326,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -867,7 +881,7 @@ REGISTER_OP("MapDefun")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<TensorShape> output_shapes;
+ std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
@@ -877,6 +891,10 @@ REGISTER_OP("MapDefun")
int64 dim_zero = -1;
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ if (c->Rank(c->input(i)) == 0) {
+ return errors::InvalidArgument(
+ "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ }
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
if (dim_zero == -1) {