diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 11:37:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 11:41:45 -0700 |
commit | 6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch) | |
tree | af6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/hlo_instruction.h | |
parent | 13c7499d5454b870eb3604d6b0ca241685cabe18 (diff) |
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 01437f66cd..b393635e9d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -470,6 +470,15 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, const std::vector<ReplicaGroup>& replica_groups); + // Creates a communitation instructions that permutes data cross replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // conssits of 0(s) in `shape`. + static std::unique_ptr<HloInstruction> CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape, @@ -1429,9 +1438,12 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector<ReplicaGroup>& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector<std::pair<int64, int64>>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); |