aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-07-10 13:04:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 13:08:18 -0700
commit1882427291f212cd04d09f2b35af4a78f6d771b5 (patch)
treec6172cec133d7f7fed37757db67822e074d658a9 /tensorflow/compiler/xla/service/hlo_instructions.cc
parentccbbd4484a29fb3ee7d0d67abcebafdb48c9059b (diff)
[XLA] Generalize sort semantics to Rk.
PiperOrigin-RevId: 203997296
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index c160647f7a..7ea42caa7b 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -469,6 +469,46 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], new_operands[1], dimensions(), to_apply());
}
+HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values)
+ : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+ AppendOperand(keys);
+ if (values) {
+ AppendOperand(values);
+ }
+}
+
+HloInstructionProto HloSortInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloSortInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloSortInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ HloInstruction* keys = new_operands[0];
+ HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
+ return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+}
+
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)