diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-07-10 13:04:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 13:08:18 -0700 |
commit | 1882427291f212cd04d09f2b35af4a78f6d771b5 (patch) | |
tree | c6172cec133d7f7fed37757db67822e074d658a9 /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | ccbbd4484a29fb3ee7d0d67abcebafdb48c9059b (diff) |
[XLA] Generalize sort semantics to Rk.
PiperOrigin-RevId: 203997296
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2c3dab0a45..f162d52d3c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -632,17 +632,22 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kSort: { auto loc = lexer_.GetLoc(); - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + + optional<std::vector<tensorflow::int64>> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { return false; } switch (operands.size()) { case 1: - instruction = builder->AddInstruction( - HloInstruction::CreateSort(shape, /*keys=*/operands[0])); + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), /*keys=*/operands[0])); break; case 2: instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, + shape, dimensions->at(0), /*keys=*/operands[0], /*values=*/operands[1])); break; default: |