diff options
author | 2018-08-28 11:37:18 -0700 | |
---|---|---|
committer | 2018-08-28 11:41:45 -0700 | |
commit | 6de10fb253098c9ff65e9d4083c4de84f3ff5f76 (patch) | |
tree | af6c7c66f13892c9df05e55c83896ae4bd67a77e | |
parent | 13c7499d5454b870eb3604d6b0ca241685cabe18 (diff) |
[XLA] Add the xla interface for CollectivePermute.
PiperOrigin-RevId: 210576458
22 files changed, 253 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 65b110e285..5e92df2d63 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1971,6 +1971,27 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, }); } +XlaOp XlaBuilder::CollectivePermute( + const XlaOp& operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCollectivePermuteShape(operand_shape)); + + for (const auto& pair : source_target_pairs) { + auto* proto_pair = instr.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + {operand}); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, @@ -2782,6 +2803,12 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, split_count, replica_groups); } +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) { + return operand.builder()->CollectivePermute(operand, source_target_pairs); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index baa2ae5184..e9d5d3943c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -715,6 +715,12 @@ class XlaBuilder { int64 concat_dimension, int64 split_count, 
const std::vector<ReplicaGroup>& replica_groups); + // Enqueues an operation that does a CollectivePermute of the operand across + // cores. + XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1260,6 +1266,9 @@ class XlaBuilder { friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector<ReplicaGroup>& replica_groups); + friend XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); friend XlaOp SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, @@ -1861,6 +1870,18 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector<ReplicaGroup>& replica_groups = {}); +// Enqueues a collective operation that sends and receives data across replicas. +// +// - `source_target_pairs`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consisting of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. 
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 49a15ec3b4..7c37ed00cd 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); } +TEST_F(XlaBuilderTest, CollectivePermute) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 275e6cc61d..f6f8fc5a2a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -107,6 +107,7 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6ec4893f7a..4f620e4c3a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase 
Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleAllToAll(HloInstructionPtr crs) override { - return DefaultAction(crs); + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 821c599863..58b7af93eb 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 52 +// Next ID: 53 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -173,6 +173,9 @@ message HloInstructionProto { // Precision configuration for the instruction. Has backend-specific meaning. xla.PrecisionConfigProto precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 5add4251ef..0e12a1ee03 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -543,6 +543,10 @@ Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { return Status::OK(); } +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. 
For now, assume diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 1bf1c4a315..c6a2007904 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index f2f9ed5969..3041d94fa9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1029,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6b4f3c4eb8..c77699a06f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -320,6 +320,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.replica_groups().end())); break; } + case HloOpcode::kCollectivePermute: { + std::vector<std::pair<int64, int64>> source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = 
proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); + break; + } case HloOpcode::kConvolution: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " @@ -681,6 +692,14 @@ HloInstruction::CreateCrossReplicaSum( replica_groups); } +/* static */ std::unique_ptr<HloInstruction> +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) { + return absl::make_unique<HloCollectivePermuteInstruction>( + shape, operand, source_target_pairs); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -1154,6 +1173,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1622,6 +1642,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -2275,6 +2296,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleCrossReplicaSum(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -3189,6 +3212,11 @@ const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { return 
Cast<HloCollectiveInstruction>(this)->replica_groups(); } +const std::vector<std::pair<int64, int64>>& +HloInstruction::source_target_pairs() const { + return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs(); +} + string HloInstruction::cross_replica_sum_barrier() const { return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 01437f66cd..b393635e9d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -470,6 +470,15 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, const std::vector<ReplicaGroup>& replica_groups); + // Creates a communication instruction that permutes data across replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // consisting of 0(s) in `shape`. + static std::unique_ptr<HloInstruction> CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape, @@ -1429,9 +1438,12 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector<ReplicaGroup>& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. 
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index b407cfeb50..b93c758937 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -416,6 +416,58 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( replica_groups()); } +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} + +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + return proto; +} + +std::vector<string> +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector<string> result; + std::vector<string> strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); + return result; +} + +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const HloCollectivePermuteInstruction&>(other); + return ContainersEqual( + 
source_target_pairs(), casted_other.source_target_pairs(), + [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) { + return a == b; + }); +} + +std::unique_ptr<HloInstruction> +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique<HloCollectivePermuteInstruction>( + shape, new_operands[0], source_target_pairs()); +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index efdb9e9781..29b187300d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -290,7 +290,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand, + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, const std::vector<ReplicaGroup>& replica_groups); private: @@ -301,6 +301,36 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { HloCloneContext* context) const override; }; +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + + const std::vector<std::pair<int64, int64>>& source_target_pairs() const { + return source_target_pairs_; + } + + // Returns a serialized representation of this instruction. 
+ HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + const std::vector<std::pair<int64, int64>> source_target_pairs_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index b8f2a21ff9..e6bfb8025d 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e4edb87aa5..c7a766f4e0 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -702,6 +702,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateAllToAll(shape, operands, replica_groups)); break; } + case HloOpcode::kCollectivePermute: { + optional<std::vector<std::vector<int64>>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + 
!ParseAttributes(attrs)) { + return false; + } + std::vector<std::pair<int64, int64>> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b3d3ccda74..b1ef288b8e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1098,6 +1098,18 @@ ENTRY AllToAllWithSubgroups { )" }, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} +} + +)" +}, // Iota { "Iota", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 81ffb5ac43..0ed2c3b449 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -116,6 +116,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), diff --git 
a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index b6093d667c..42e3027bf1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -47,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 6207cdfb0d..83313c7ec1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -122,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index b04d2a7ba6..a04af8b0aa 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1844,6 +1844,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, 
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 4974ac9916..235b1a4cf3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -136,6 +136,9 @@ class ShapeInference { static StatusOr<Shape> InferAllToAllTupleShape( tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers the shape of a collective permute operation. + static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 9451e0c315..aaba5aa92e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -570,6 +570,12 @@ message ReplicaGroup { repeated int64 replica_ids = 1; } +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + // Used to indicate the precision configuration. It has backend specific // meaning. message PrecisionConfigProto { |