aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc65
1 files changed, 52 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffe..25d5327561 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ CHECK(get_channel_constraints(instruction))
+ << "Multi-module layout assignment requires ChannelLayoutConstraints";
+ int64 all_reduce_id = instruction->all_reduce_id().value();
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(all_reduce_id)) {
+ continue;
+ }
+ // TODO(b/68493863): Change to use SetOperandLayout().
+ const Shape& buffer_shape = instruction->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
-
- // Copy the root instruction's result if its layout does not match the result
- // layout constraint.
- if (constraints.ResultLayout() != nullptr &&
- !constraints.ResultLayout()->MatchesLayoutInShape(
- computation->root_instruction()->shape())) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_root,
- CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
- computation->root_instruction()));
- computation->set_root_instruction(new_root);
- }
-
return Status::OK();
}
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
+
+ // Copy the root instruction's result if its layout does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
return Status::OK();
}
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
*send_shape = shape;
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->all_reduce_id().value(),
+ instruction->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the channel wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ HloInstruction* operand = instruction->mutable_operand(0);
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ *instruction->mutable_shape() = shape;
+ }
}
}
return Status::OK();