aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc25
1 files changed, 15 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index e379911462..aa49f98bcf 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1029,17 +1029,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case HloOpcode::kSort: {
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
- } else if (operand_shapes.size() == 2) {
- if (!ShapeUtil::SameDimensions(*operand_shapes[0],
- *operand_shapes[1])) {
- return InvalidArgument(
- "Sort keys and values dimensions must match. "
- "Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(*operand_shapes[0]),
- ShapeUtil::HumanString(*operand_shapes[1]));
+ } else {
+ for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
+ if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+ *operand_shapes[operand])) {
+ return InvalidArgument(
+ "Sort keys and values dimensions must match. "
+ "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
+ ShapeUtil::HumanString(*operand_shapes[0]), operand,
+ ShapeUtil::HumanString(*operand_shapes[operand]));
+ }
+ }
+ std::vector<Shape> operand_shape_values;
+ for (const Shape* operand_shape : operand_shapes) {
+ operand_shape_values.push_back(*operand_shape);
}
- return ShapeUtil::MakeTupleShape(
- {*operand_shapes[0], *operand_shapes[1]});
+ return ShapeUtil::MakeTupleShape(operand_shape_values);
}
return InvalidArgument("Unexpected number of operands for sort");
}