diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 620458855f..a1f668921d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -266,18 +266,20 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - if (sort->operand_count() < 1 || sort->operand_count() > 2) { - return InternalError("Expected 1 or 2 operands for %s instruction: %s", + if (sort->operand_count() < 1) { + return InternalError("Expected at least 1 operand for %s instruction: %s", HloOpcodeString(sort->opcode()), sort->ToString()); } - if (sort->operand_count() == 2 && - !ShapeUtil::SameDimensions(sort->operand(0)->shape(), - sort->operand(1)->shape())) { - return InternalError( - "Expected sort to have to have the same dimensions for the keys and " - "the values. Keys shape is: %s\n, Values shape is: %s", - StringifyShape(sort->operand(0)->shape()), - StringifyShape(sort->operand(1)->shape())); + for (int64 operand = 1; operand < sort->operand_count(); ++operand) { + if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(operand)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys " + "and the values. Keys shape is: %s\n, Values shape (operand index " + "%lld) is: %s", + StringifyShape(sort->operand(0)->shape()), operand, + StringifyShape(sort->operand(operand)->shape())); + } } return CheckVariadicShape(sort); } |