diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-07-10 13:04:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 13:08:18 -0700 |
commit | 1882427291f212cd04d09f2b35af4a78f6d771b5 (patch) | |
tree | c6172cec133d7f7fed37757db67822e074d658a9 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | ccbbd4484a29fb3ee7d0d67abcebafdb48c9059b (diff) |
[XLA] Generalize sort semantics to Rk.
PiperOrigin-RevId: 203997296
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index c160647f7a..7ea42caa7b 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -469,6 +469,46 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], new_operands[1], dimensions(), to_apply()); } +HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values) + : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { + AppendOperand(keys); + if (values) { + AppendOperand(values); + } +} + +HloInstructionProto HloSortInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloSortInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloSortInstruction&>(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + HloInstruction* keys = new_operands[0]; + HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; + return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values); +} + HloTransposeInstruction::HloTransposeInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) |