aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-07-10 13:04:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 13:08:18 -0700
commit1882427291f212cd04d09f2b35af4a78f6d771b5 (patch)
treec6172cec133d7f7fed37757db67822e074d658a9 /tensorflow/compiler/xla/service/hlo_parser.cc
parentccbbd4484a29fb3ee7d0d67abcebafdb48c9059b (diff)
[XLA] Generalize sort semantics to Rk.
PiperOrigin-RevId: 203997296
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 2c3dab0a45..f162d52d3c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -632,17 +632,22 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kSort: {
auto loc = lexer_.GetLoc();
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
+ dimensions->size() != 1) {
return false;
}
switch (operands.size()) {
case 1:
- instruction = builder->AddInstruction(
- HloInstruction::CreateSort(shape, /*keys=*/operands[0]));
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0), /*keys=*/operands[0]));
break;
case 2:
instruction = builder->AddInstruction(HloInstruction::CreateSort(
- shape,
+ shape, dimensions->at(0),
/*keys=*/operands[0], /*values=*/operands[1]));
break;
default: