aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc47
1 files changed, 36 insertions, 11 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index a3b0512304..4741bc968a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -416,24 +416,49 @@ REGISTER_OP("SplitV")
.Attr("T: type")
.Attr("Tlen: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
- ShapeHandle unused;
+ DimensionHandle split_dimension;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension));
int32 num_outputs = c->num_outputs();
- // Return unknown shapes with the same rank as the input
- // or unknown rank if input's rank isn't known
- // can't determine exact shapes until runtime because
- // we don't know where the tensor containing the split sizes
- // is located
- int32 rank = c->Rank(c->input(0));
+ ShapeHandle input = c->input(0);
+ int32 rank = c->Rank(input);
ShapeHandle output_shape;
+ const Tensor* size_splits = c->input_tensor(1);
if (rank == InferenceContext::kUnknownRank) {
+ // If the rank of input tensor is unknown, then return unkown shapes.
output_shape = c->UnknownShape();
+ for (int i = 0; i < num_outputs; ++i) {
+ c->set_output(i, output_shape);
+ }
} else if (rank == 0) {
+ // Throw error if input is a scalar.
return errors::InvalidArgument("Can't split scalars");
- } else {
+ } else if (size_splits == nullptr || !c->ValueKnown(split_dimension)) {
+ // If split dimension or tensor containing the split sizes is unkown,
+ // then return unknown shapes of same rank as input.
output_shape = c->UnknownShapeOfRank(rank);
- }
- for (int i = 0; i < num_outputs; ++i) {
- c->set_output(i, output_shape);
+ for (int i = 0; i < num_outputs; ++i) {
+ c->set_output(i, output_shape);
+ }
+ } else {
+ // Determine the output shape if split dimension and split sizes are known
+ int64 split_dim = c->Value(split_dimension);
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
+ std::vector<int64> data;
+ if (size_splits->dtype() == DT_INT32) {
+ data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0));
+ } else {
+ data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0));
+ }
+ if (num_outputs != data.size()) {
+ return errors::InvalidArgument(
+ "Length of size_splits should be equal to num_outputs");
+ }
+ for (int i = 0; i < num_outputs; ++i) {
+ output_shape = c->UnknownShapeOfRank(rank);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape));
+ c->set_output(i, output_shape);
+ }
}
return Status::OK();