aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 11:37:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 11:41:45 -0700
commit6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch)
treeaf6c7c66f13892c9df05e55c83896ae4bd67a77e /tensorflow/compiler/xla/service/hlo_instruction.h
parent13c7499d5454b870eb3604d6b0ca241685cabe18 (diff)
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 01437f66cd..b393635e9d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -470,6 +470,15 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
const std::vector<ReplicaGroup>& replica_groups);
+ // Creates a communitation instructions that permutes data cross replicas.
+ // Data is sent/received according to the (source_replica_id,
+ // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
+ // target_replica_id in any pair, the output on that replica is a tensor
+ // conssits of 0(s) in `shape`.
+ static std::unique_ptr<HloInstruction> CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
@@ -1429,9 +1438,12 @@ class HloInstruction {
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;
- // Delegates to HloAllToAllInstruction::replica_groups.
+ // Delegates to HloCollectiveInstruction::replica_groups.
const std::vector<ReplicaGroup>& replica_groups() const;
+ // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);