diff options
Diffstat (limited to 'tensorflow/core/ops/dataset_ops.cc')
-rw-r--r-- | tensorflow/core/ops/dataset_ops.cc | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 1a5ad8f421..9d2b3af51d 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -210,6 +210,7 @@ REGISTER_OP("ParallelMapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("MapAndBatchDataset") @@ -326,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ParallelInterleaveDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("num_parallel_calls: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByReducerDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") @@ -867,7 +881,7 @@ REGISTER_OP("MapDefun") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector<TensorShape> output_shapes; + std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( @@ -877,6 +891,10 @@ REGISTER_OP("MapDefun") int64 dim_zero = -1; for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + if (c->Rank(c->input(i)) == 0) { + return errors::InvalidArgument( + "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { if (dim_zero == -1) { |