aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-10-09 19:41:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 19:46:05 -0700
commit58fcfc98cd59ae3952399fc55380b8733df08df9 (patch)
tree24d5ac5d6691e73c227f5afa5ef68ba2ecba4ec0 /tensorflow/compiler/xla/service/hlo_instruction.cc
parent93eef55c4d04af24a6c8080f34629db179634f07 (diff)
[XLA] Add documentation and HLO-level support for multi-value sort.
No support in any of the backends, and not yet exposed through XlaBuilder. PiperOrigin-RevId: 216465753
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc17
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 09bcf8a9e7..c317e9e3b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -195,17 +195,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
break;
case HloOpcode::kSort: {
- TF_RET_CHECK(proto.operand_ids_size() == 1 ||
- proto.operand_ids_size() == 2)
- << "Sort instruction should have 1 or 2 operands but has "
+ TF_RET_CHECK(proto.operand_ids_size() >= 1)
+ << "Sort instruction should have at least 1 operand but has "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.dimensions().size() == 1)
<< "Sort instruction should have 1 dimension";
- HloInstruction* keys = operands(0);
- HloInstruction* values =
- proto.operand_ids_size() == 2 ? operands(1) : nullptr;
- instruction =
- CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ auto sort_operands = all_operands();
+ HloInstruction* keys = sort_operands[0];
+ instruction = CreateSort(
+ proto.shape(), proto.dimensions(0), keys,
+ absl::Span<HloInstruction* const>(sort_operands).subspan(1));
break;
}
case HloOpcode::kTranspose:
@@ -1078,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
- HloInstruction* values) {
+ absl::Span<HloInstruction* const> values) {
return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}