diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index e379911462..aa49f98bcf 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1029,17 +1029,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kSort: { if (operand_shapes.size() == 1) { return *operand_shapes[0]; - } else if (operand_shapes.size() == 2) { - if (!ShapeUtil::SameDimensions(*operand_shapes[0], - *operand_shapes[1])) { - return InvalidArgument( - "Sort keys and values dimensions must match. " - "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]), - ShapeUtil::HumanString(*operand_shapes[1])); + } else { + for (int64 operand = 1; operand < operand_shapes.size(); ++operand) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[operand])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape (operand index %lld) is: %s", + ShapeUtil::HumanString(*operand_shapes[0]), operand, + ShapeUtil::HumanString(*operand_shapes[operand])); + } + } + std::vector<Shape> operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); } - return ShapeUtil::MakeTupleShape( - {*operand_shapes[0], *operand_shapes[1]}); + return ShapeUtil::MakeTupleShape(operand_shape_values); } return InvalidArgument("Unexpected number of operands for sort"); } |