diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 09bcf8a9e7..c317e9e3b4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -195,17 +195,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } break; case HloOpcode::kSort: { - TF_RET_CHECK(proto.operand_ids_size() == 1 || - proto.operand_ids_size() == 2) - << "Sort instruction should have 1 or 2 operands but has " + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "Sort instruction should have at least 1 operand but has " << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; - HloInstruction* keys = operands(0); - HloInstruction* values = - proto.operand_ids_size() == 2 ? operands(1) : nullptr; - instruction = - CreateSort(proto.shape(), proto.dimensions(0), keys, values); + auto sort_operands = all_operands(); + HloInstruction* keys = sort_operands[0]; + instruction = CreateSort( + proto.shape(), proto.dimensions(0), keys, + absl::Span<HloInstruction* const>(sort_operands).subspan(1)); break; } case HloOpcode::kTranspose: @@ -1078,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) { + absl::Span<HloInstruction* const> values) { return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values); } |