diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 49 |
1 files changed, 26 insertions, 23 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index d6ae75473f..ef8ad7972c 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -631,38 +631,41 @@ REGISTER_OP("SplitV") return errors::InvalidArgument( "Length of size_splits should be equal to num_outputs"); } - int64_t cumsum_outputs = 0; + int64_t total_size = 0; bool has_neg_one = false; + for (const auto size : data) { + if (size == -1) { + if (has_neg_one) { + return errors::InvalidArgument( + "size_splits can only have one -1"); + } + has_neg_one = true; + } else { + total_size += size; + } + } + auto split_dim_size = c->Value(c->Dim(input, split_dim)); // If the sizes of the splits are known, then // make sure that the sizes add up to the expected // dimension size, with the possibility of a -1. // Specify the full output shapes. for (int i = 0; i < num_outputs; ++i) { - output_shape = c->UnknownShapeOfRank(rank); - TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim, - c->MakeDim(data[i]), &output_shape)); + auto size = data[i]; + if (data[i] == -1 && c->ValueKnown(split_dim_size)) { + size = split_dim_size - total_size; + } + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape)); c->set_output(i, output_shape); - if (data[i] == -1 && !has_neg_one) - has_neg_one = true; - else if (data[i] == -1 && has_neg_one) - return errors::InvalidArgument("size_splits can only have one -1"); - else - cumsum_outputs += data[i]; } - auto split_dim_size = c->Value(c->Dim(input, split_dim)); - if (has_neg_one) { - if (cumsum_outputs < split_dim_size) - cumsum_outputs = split_dim_size; - else - cumsum_outputs = split_dim_size + 1; + if (c->ValueKnown(split_dim_size)) { + if (has_neg_one ? total_size > split_dim_size + : total_size != split_dim_size) { + return errors::InvalidArgument( + "can't split axis of size ", split_dim_size, + " into pieces of size [", str_util::Join(data, ","), "]"); + } } - if (c->ValueKnown(c->Dim(input, split_dim)) && - cumsum_outputs != c->Value(c->Dim(input, split_dim))) - return errors::InvalidArgument( - "Sum of output sizes must match " - "the size of the original Tensor along the split dimension " - "or the sum of the positive sizes must be less if it contains a " - "-1"); } return Status::OK(); |