aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 11:37:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 11:41:45 -0700
commit6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch)
treeaf6c7c66f13892c9df05e55c83896ae4bd67a77e
parent13c7499d5454b870eb3604d6b0ca241685cabe18 (diff)
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc27
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h21
-rw-r--r--tensorflow/compiler/xla/client/xla_builder_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h32
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc6
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/compiler/xla/xla_data.proto6
22 files changed, 253 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 65b110e285..5e92df2d63 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1971,6 +1971,27 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
});
}
+XlaOp XlaBuilder::CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferCollectivePermuteShape(operand_shape));
+
+ for (const auto& pair : source_target_pairs) {
+ auto* proto_pair = instr.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
+ {operand});
+ });
+}
+
XlaOp XlaBuilder::SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@@ -2782,6 +2803,12 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
split_count, replica_groups);
}
+XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return operand.builder()->CollectivePermute(operand, source_target_pairs);
+}
+
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index baa2ae5184..e9d5d3943c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -715,6 +715,12 @@ class XlaBuilder {
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
+ // Enqueues an operation that does a CollectivePermute of the operand across
+ // cores.
+ XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
@@ -1260,6 +1266,9 @@ class XlaBuilder {
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
+ friend XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
friend XlaOp SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
@@ -1861,6 +1870,18 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {});
+// Enqueues a collective operation that sends and receives data across replicas.
+//
+// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
+// pairs. For each pair, the operand is sent from the source replica to the
+// target replica. Note that, 1) any two pairs should not have the same target
+// replica id, and they should not have the same source replica id; 2) if a
+// replica id is not a target in any pair, then the output on that replica is a
+// tensor consisting of 0(s) with the same shape as the input.
+XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 49a15ec3b4..7c37ed00cd 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
+TEST_F(XlaBuilderTest, CollectivePermute) {
+ XlaBuilder b(TestName());
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute);
+}
+
TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 275e6cc61d..f6f8fc5a2a 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -107,6 +107,7 @@ class DfsHloVisitorBase {
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
+ virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 6ec4893f7a..4f620e4c3a 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
- Status HandleAllToAll(HloInstructionPtr crs) override {
- return DefaultAction(crs);
+ Status HandleAllToAll(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
+ }
+ Status HandleCollectivePermute(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
}
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 821c599863..58b7af93eb 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 52
+// Next ID: 53
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -173,6 +173,9 @@ message HloInstructionProto {
// Precision configuration for the instruction. Has backend-specific meaning.
xla.PrecisionConfigProto precision_config = 51;
+
+ // Collective permute field.
+ repeated SourceTarget source_target_pairs = 52;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 5add4251ef..0e12a1ee03 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -543,6 +543,10 @@ Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
return Status::OK();
}
+Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
// TODO(b/26346211): Implement better estimates for the RNG cost, since the
// cost changes with the implementation and the distribution. For now, assume
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 1bf1c4a315..c6a2007904 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override;
+ Status HandleCollectivePermute(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
Status HandleRng(const HloInstruction* random) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index f2f9ed5969..3041d94fa9 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1029,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kGray;
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kRecv:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6b4f3c4eb8..c77699a06f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -320,6 +320,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.replica_groups().end()));
break;
}
+ case HloOpcode::kCollectivePermute: {
+ std::vector<std::pair<int64, int64>> source_target_pairs(
+ proto.source_target_pairs_size());
+ for (int i = 0; i < source_target_pairs.size(); i++) {
+ source_target_pairs[i].first = proto.source_target_pairs(i).source();
+ source_target_pairs[i].second = proto.source_target_pairs(i).target();
+ }
+ instruction = CreateCollectivePermute(proto.shape(), operands(0),
+ source_target_pairs);
+ break;
+ }
case HloOpcode::kConvolution:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
@@ -681,6 +692,14 @@ HloInstruction::CreateCrossReplicaSum(
replica_groups);
}
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, operand, source_target_pairs);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config) {
@@ -1154,6 +1173,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReducePrecision:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
@@ -1622,6 +1642,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
@@ -2275,6 +2296,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
+ case HloOpcode::kCollectivePermute:
+ return visitor->HandleCollectivePermute(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
@@ -3189,6 +3212,11 @@ const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
return Cast<HloCollectiveInstruction>(this)->replica_groups();
}
+const std::vector<std::pair<int64, int64>>&
+HloInstruction::source_target_pairs() const {
+ return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
+}
+
string HloInstruction::cross_replica_sum_barrier() const {
return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 01437f66cd..b393635e9d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -470,6 +470,15 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
const std::vector<ReplicaGroup>& replica_groups);
+ // Creates a communication instruction that permutes data across replicas.
+ // Data is sent/received according to the (source_replica_id,
+ // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
+ // target_replica_id in any pair, the output on that replica is a tensor
+ // consisting of 0(s) in `shape`.
+ static std::unique_ptr<HloInstruction> CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
@@ -1429,9 +1438,12 @@ class HloInstruction {
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;
- // Delegates to HloAllToAllInstruction::replica_groups.
+ // Delegates to HloCollectiveInstruction::replica_groups.
const std::vector<ReplicaGroup>& replica_groups() const;
+ // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index b407cfeb50..b93c758937 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -416,6 +416,58 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
replica_groups());
}
+HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs)
+ : HloInstruction(HloOpcode::kCollectivePermute, shape),
+ source_target_pairs_(source_target_pairs) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (const auto& pair : source_target_pairs()) {
+ auto* proto_pair = proto.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
+ }
+ return proto;
+}
+
+std::vector<string>
+HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> strs;
+ for (const auto& pair : source_target_pairs()) {
+ strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
+ }
+ result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
+ return result;
+}
+
+bool HloCollectivePermuteInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectivePermuteInstruction&>(other);
+ return ContainersEqual(
+ source_target_pairs(), casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) {
+ return a == b;
+ });
+}
+
+std::unique_ptr<HloInstruction>
+HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, new_operands[0], source_target_pairs());
+}
+
HloReverseInstruction::HloReverseInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index efdb9e9781..29b187300d 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -290,7 +290,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
const std::vector<ReplicaGroup>& replica_groups);
private:
@@ -301,6 +301,36 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
HloCloneContext* context) const override;
};
+class HloCollectivePermuteInstruction : public HloInstruction {
+ public:
+ explicit HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
+ return source_target_pairs_;
+ }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ const std::vector<std::pair<int64, int64>> source_target_pairs_;
+};
+
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index b8f2a21ff9..e6bfb8025d 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -58,6 +58,7 @@ namespace xla {
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
+ V(kCollectivePermute, "collective-permute") \
V(kClz, "count-leading-zeros") \
V(kComplex, "complex") \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index e4edb87aa5..c7a766f4e0 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -702,6 +702,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateAllToAll(shape, operands, replica_groups));
break;
}
+ case HloOpcode::kCollectivePermute: {
+ optional<std::vector<std::vector<int64>>> source_targets;
+ attrs["source_target_pairs"] = {
+ /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ std::vector<std::pair<int64, int64>> pairs(source_targets->size());
+ for (int i = 0; i < pairs.size(); i++) {
+ if ((*source_targets)[i].size() != 2) {
+ return TokenError(
+ "expects 'source_target_pairs=' to be a list of pairs");
+ }
+ pairs[i].first = (*source_targets)[i][0];
+ pairs[i].second = (*source_targets)[i][1];
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCollectivePermute(shape, operands[0], pairs));
+ break;
+ }
case HloOpcode::kReshape: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index b3d3ccda74..b1ef288b8e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1098,6 +1098,18 @@ ENTRY AllToAllWithSubgroups {
)"
},
+// collective-permute
+{
+"CollectivePermute",
+R"(HloModule CollectivePermute
+
+ENTRY CollectivePermute {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
+}
+
+)"
+},
// Iota
{
"Iota",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 81ffb5ac43..0ed2c3b449 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -116,6 +116,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
ShapeInference::InferAllToAllTupleShape(operand_shapes));
}
+Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
+ return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
+ hlo->operand(0)->shape()));
+}
+
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index b6093d667c..42e3027bf1 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -47,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;
+ Status HandleCollectivePermute(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override;
Status HandleInfeed(HloInstruction*) override;
Status HandleOutfeed(HloInstruction*) override;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 6207cdfb0d..83313c7ec1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -122,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kConvolution:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDomain:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index b04d2a7ba6..a04af8b0aa 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1844,6 +1844,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
}
+/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape));
+ return shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 4974ac9916..235b1a4cf3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -136,6 +136,9 @@ class ShapeInference {
static StatusOr<Shape> InferAllToAllTupleShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ // Infers the shape of a collective permute operation.
+ static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
+
// Infers the shape produced by applying the given reduction computation
// shape to the given input operand shape.
//
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 9451e0c315..aaba5aa92e 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -570,6 +570,12 @@ message ReplicaGroup {
repeated int64 replica_ids = 1;
}
+// Describes the source target pair in the collective permute op.
+message SourceTarget {
+ int64 source = 1;
+ int64 target = 2;
+}
+
// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfigProto {