aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc22
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 620458855f..a1f668921d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -266,18 +266,20 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
- if (sort->operand_count() < 1 || sort->operand_count() > 2) {
- return InternalError("Expected 1 or 2 operands for %s instruction: %s",
+ if (sort->operand_count() < 1) {
+ return InternalError("Expected at least 1 operand for %s instruction: %s",
HloOpcodeString(sort->opcode()), sort->ToString());
}
- if (sort->operand_count() == 2 &&
- !ShapeUtil::SameDimensions(sort->operand(0)->shape(),
- sort->operand(1)->shape())) {
- return InternalError(
- "Expected sort to have to have the same dimensions for the keys and "
- "the values. Keys shape is: %s\n, Values shape is: %s",
- StringifyShape(sort->operand(0)->shape()),
- StringifyShape(sort->operand(1)->shape()));
+ for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
+ if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+ sort->operand(operand)->shape())) {
+ return InternalError(
+ "Expected sort to have to have the same dimensions for the keys "
+ "and the values. Keys shape is: %s\n, Values shape (operand index "
+ "%lld) is: %s",
+ StringifyShape(sort->operand(0)->shape()), operand,
+ StringifyShape(sort->operand(operand)->shape()));
+ }
}
return CheckVariadicShape(sort);
}