diff options
author | 2018-08-28 11:37:18 -0700 | |
---|---|---|
committer | 2018-08-28 11:41:45 -0700 | |
commit | 6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch) | |
tree | af6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 13c7499d5454b870eb3604d6b0ca241685cabe18 (diff) |
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index b407cfeb50..b93c758937 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -416,6 +416,58 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( replica_groups()); } +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} + +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + return proto; +} + +std::vector<string> +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector<string> result; + std::vector<string> strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); + return result; +} + +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const HloCollectivePermuteInstruction&>(other); + return ContainersEqual( + source_target_pairs(), casted_other.source_target_pairs(), + [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) { + return a == b; + }); +} + +std::unique_ptr<HloInstruction> +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique<HloCollectivePermuteInstruction>( + shape, new_operands[0], source_target_pairs()); +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) |