aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc49
1 files changed, 26 insertions, 23 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index d6ae75473f..ef8ad7972c 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -631,38 +631,41 @@ REGISTER_OP("SplitV")
return errors::InvalidArgument(
"Length of size_splits should be equal to num_outputs");
}
- int64_t cumsum_outputs = 0;
+ int64_t total_size = 0;
bool has_neg_one = false;
+ for (const auto size : data) {
+ if (size == -1) {
+ if (has_neg_one) {
+ return errors::InvalidArgument(
+ "size_splits can only have one -1");
+ }
+ has_neg_one = true;
+ } else {
+ total_size += size;
+ }
+ }
+ auto split_dim_size = c->Value(c->Dim(input, split_dim));
// If the sizes of the splits are known, then
// make sure that the sizes add up to the expected
// dimension size, with the possibility of a -1.
// Specify the full output shapes.
for (int i = 0; i < num_outputs; ++i) {
- output_shape = c->UnknownShapeOfRank(rank);
- TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
- c->MakeDim(data[i]), &output_shape));
+ auto size = data[i];
+ if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
+ size = split_dim_size - total_size;
+ }
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
c->set_output(i, output_shape);
- if (data[i] == -1 && !has_neg_one)
- has_neg_one = true;
- else if (data[i] == -1 && has_neg_one)
- return errors::InvalidArgument("size_splits can only have one -1");
- else
- cumsum_outputs += data[i];
}
- auto split_dim_size = c->Value(c->Dim(input, split_dim));
- if (has_neg_one) {
- if (cumsum_outputs < split_dim_size)
- cumsum_outputs = split_dim_size;
- else
- cumsum_outputs = split_dim_size + 1;
+ if (c->ValueKnown(split_dim_size)) {
+ if (has_neg_one ? total_size > split_dim_size
+ : total_size != split_dim_size) {
+ return errors::InvalidArgument(
+ "can't split axis of size ", split_dim_size,
+ " into pieces of size [", str_util::Join(data, ","), "]");
+ }
}
- if (c->ValueKnown(c->Dim(input, split_dim)) &&
- cumsum_outputs != c->Value(c->Dim(input, split_dim)))
- return errors::InvalidArgument(
- "Sum of output sizes must match "
- "the size of the original Tensor along the split dimension "
- "or the sum of the positive sizes must be less if it contains a "
- "-1");
}
return Status::OK();