 tensorflow/compiler/xla/service/layout_assignment.cc      | 65
 tensorflow/compiler/xla/service/layout_assignment_test.cc | 44
 2 files changed, 96 insertions(+), 13 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffe..25d5327561 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ CHECK(get_channel_constraints(instruction))
+ << "Multi-module layout assignment requires ChannelLayoutConstraints";
+ int64 all_reduce_id = instruction->all_reduce_id().value();
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(all_reduce_id)) {
+ continue;
+ }
+ // TODO(b/68493863): Change to use SetOperandLayout().
+ const Shape& buffer_shape = instruction->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
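
Note on the hunk above: the new branch only fires once the all-reduce's channel already carries a layout constraint, so some participant has to register one first (either ConstrainChannelLayouts running over another module, or the caller seeding the constraints). A minimal sketch of such seeding, assuming only the ChannelLayoutConstraints calls visible in this diff; the channel id 0 and the {0,1} minor-to-major order are illustrative:

    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/service/layout_assignment.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Register a layout for all-reduce channel 0 before running the pass.
    ChannelLayoutConstraints channel_constraints;
    channel_constraints.ConstrainChannel(0, LayoutUtil::MakeLayout({0, 1}));

    // What the new branch then resolves for a f32[2,2] all-reduce operand:
    Shape buffer_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
    if (channel_constraints.IsChannelConstrained(0)) {
      // new_buffer_shape is f32[2,2] with the constrained {0,1} layout and
      // is what gets handed to SetInstructionLayout above.
      Shape new_buffer_shape =
          channel_constraints.LayoutShapeForChannel(buffer_shape, 0);
    }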
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
-
- // Copy the root instruction's result if its layout does not match the result
- // layout constraint.
- if (constraints.ResultLayout() != nullptr &&
- !constraints.ResultLayout()->MatchesLayoutInShape(
- computation->root_instruction()->shape())) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_root,
- CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
- computation->root_instruction()));
- computation->set_root_instruction(new_root);
- }
-
return Status::OK();
}
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
+
+ // Copy the root instruction's result if its layout does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
return Status::OK();
}
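
The two hunks above move the root-copy insertion out of AssignLayouts and into RunOnComputation so that it runs after ConstrainChannelLayouts: a root that is itself a cross-module all-reduce keeps the layout imposed by its channel, and the entry ComputationLayout's result constraint is then satisfied by an explicit copy. A sketch of the structure this is expected to produce, using the layouts from the test added below ({0,1} on the channel, {1,0} on the result):

    // After the pass, the root should be a kCopy carrying the result layout,
    // with its all-reduce operand keeping the channel-constrained layout.
    const HloInstruction* root =
        module->entry_computation()->root_instruction();
    if (root->opcode() == HloOpcode::kCopy) {
      CHECK(LayoutUtil::Equal(root->shape().layout(),
                              LayoutUtil::MakeLayout({1, 0})));
      CHECK(LayoutUtil::Equal(root->operand(0)->shape().layout(),
                              LayoutUtil::MakeLayout({0, 1})));
    }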
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
*send_shape = shape;
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->all_reduce_id().value(),
+ instruction->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the channel wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ HloInstruction* operand = instruction->mutable_operand(0);
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ *instruction->mutable_shape() = shape;
+ }
}
}
return Status::OK();
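
The ConstrainChannelLayouts hunk relies on ConstrainChannel returning the previously registered layout only when it conflicts with the requested one (per the comment in the hunk), and nullptr otherwise. A small sketch of that assumed contract; the channel id and layouts are illustrative:

    ChannelLayoutConstraints constraints;
    // First registration for channel 0: nothing to reconcile, returns null.
    const Layout* prior =
        constraints.ConstrainChannel(0, LayoutUtil::MakeLayout({0, 1}));
    CHECK(prior == nullptr);
    // A later, conflicting request returns the already-registered {0,1}
    // layout; that non-null result is what triggers the kCopy marshaling
    // in the hunk above.
    prior = constraints.ConstrainChannel(0, LayoutUtil::MakeLayout({1, 0}));
    CHECK(prior != nullptr);
    CHECK(LayoutUtil::Equal(*prior, LayoutUtil::MakeLayout({0, 1})));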
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 752a61476d..10f9a95121 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -860,6 +860,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
+TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
+ // Pin non-matching layouts to parameter and root.
+ const char* module_str = R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry_computation {
+ param = (f32[2,2]) parameter(0)
+ gte = f32[2,2] get-tuple-element(param), index=0
+ ar.0 = f32[2,2] cross-replica-sum(gte),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=0}
+ const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
+ ROOT ar.1 = f32[2,2] cross-replica-sum(const),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=1}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ Shape param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+ TF_ASSERT_OK(
+ computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+ param_shape));
+ computation_layout.mutable_result_layout()->ResetLayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ ChannelLayoutConstraints channel_constraints;
+ AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+ EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
+}
+
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
const char* module_str = R"(
HloModule CopySliceOperandToAvoidImplicitLayoutChange