aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc17
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 09bcf8a9e7..c317e9e3b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -195,17 +195,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
break;
case HloOpcode::kSort: {
- TF_RET_CHECK(proto.operand_ids_size() == 1 ||
- proto.operand_ids_size() == 2)
- << "Sort instruction should have 1 or 2 operands but has "
+ TF_RET_CHECK(proto.operand_ids_size() >= 1)
+ << "Sort instruction should have at least 1 operand but has "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.dimensions().size() == 1)
<< "Sort instruction should have 1 dimension";
- HloInstruction* keys = operands(0);
- HloInstruction* values =
- proto.operand_ids_size() == 2 ? operands(1) : nullptr;
- instruction =
- CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ auto sort_operands = all_operands();
+ HloInstruction* keys = sort_operands[0];
+ instruction = CreateSort(
+ proto.shape(), proto.dimensions(0), keys,
+ absl::Span<HloInstruction* const>(sort_operands).subspan(1));
break;
}
case HloOpcode::kTranspose:
@@ -1078,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
- HloInstruction* values) {
+ absl::Span<HloInstruction* const> values) {
return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}