diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 47 |
1 files changed, 36 insertions, 11 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a3b0512304..4741bc968a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -416,24 +416,49 @@ REGISTER_OP("SplitV") .Attr("T: type") .Attr("Tlen: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { - ShapeHandle unused; + DimensionHandle split_dimension; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &split_dimension)); int32 num_outputs = c->num_outputs(); - // Return unknown shapes with the same rank as the input - // or unknown rank if input's rank isn't known - // can't determine exact shapes until runtime because - // we don't know where the tensor containing the split sizes - // is located - int32 rank = c->Rank(c->input(0)); + ShapeHandle input = c->input(0); + int32 rank = c->Rank(input); ShapeHandle output_shape; + const Tensor* size_splits = c->input_tensor(1); if (rank == InferenceContext::kUnknownRank) { + // If the rank of input tensor is unknown, then return unkown shapes. output_shape = c->UnknownShape(); + for (int i = 0; i < num_outputs; ++i) { + c->set_output(i, output_shape); + } } else if (rank == 0) { + // Throw error if input is a scalar. return errors::InvalidArgument("Can't split scalars"); - } else { + } else if (size_splits == nullptr || !c->ValueKnown(split_dimension)) { + // If split dimension or tensor containing the split sizes is unkown, + // then return unknown shapes of same rank as input. output_shape = c->UnknownShapeOfRank(rank); - } - for (int i = 0; i < num_outputs; ++i) { - c->set_output(i, output_shape); + for (int i = 0; i < num_outputs; ++i) { + c->set_output(i, output_shape); + } + } else { + // Determine the output shape if split dimension and split sizes are known + int64 split_dim = c->Value(split_dimension); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); + std::vector<int64> data; + if (size_splits->dtype() == DT_INT32) { + data = AsInt64<int32>(size_splits, size_splits->shape().dim_size(0)); + } else { + data = AsInt64<int64>(size_splits, size_splits->shape().dim_size(0)); + } + if (num_outputs != data.size()) { + return errors::InvalidArgument( + "Length of size_splits should be equal to num_outputs"); + } + for (int i = 0; i < num_outputs; ++i) { + output_shape = c->UnknownShapeOfRank(rank); + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape)); + c->set_output(i, output_shape); + } } return Status::OK(); |