aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 11:37:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 11:41:45 -0700
commit6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch)
treeaf6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/hlo_instructions.cc
parent13c7499d5454b870eb3604d6b0ca241685cabe18 (diff)
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index b407cfeb50..b93c758937 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -416,6 +416,58 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
replica_groups());
}
+HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs)
+ : HloInstruction(HloOpcode::kCollectivePermute, shape),
+ source_target_pairs_(source_target_pairs) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (const auto& pair : source_target_pairs()) {
+ auto* proto_pair = proto.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
+ }
+ return proto;
+}
+
+std::vector<string>
+HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> strs;
+ for (const auto& pair : source_target_pairs()) {
+ strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
+ }
+ result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
+ return result;
+}
+
+bool HloCollectivePermuteInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectivePermuteInstruction&>(other);
+ return ContainersEqual(
+ source_target_pairs(), casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) {
+ return a == b;
+ });
+}
+
+std::unique_ptr<HloInstruction>
+HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, new_operands[0], source_target_pairs());
+}
+
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)