diff options
author | 2018-10-02 03:01:09 -0700 | |
---|---|---|
committer | 2018-10-02 03:05:37 -0700 | |
commit | 44da41e4900c3fd481f12c9aa4c49679c9f32fa4 (patch) | |
tree | 8aeb076e3ff0cdfd3d9e5e50661b2963e6b38ac1 /tensorflow/compiler/xla/service/layout_assignment.cc | |
parent | edea1be5dd98775399dbd12728e86039a14fb967 (diff) |
Fix layout assignment for cross module all reduce
Previously we could have ended up with different HLOs being assigned
different layouts, which made lowering impossible. This change enforces a
consistent layout between the communicating nodes, the same way it is
done for send&recv pairs.
PiperOrigin-RevId: 215359420
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/layout_assignment.cc | 65 |
1 file changed, 52 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 082bf8bffe..25d5327561 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (instruction->IsCrossModuleAllReduce()) { + CHECK(get_channel_constraints(instruction)) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 all_reduce_id = instruction->all_reduce_id().value(); + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(all_reduce_id)) { + continue; + } + // TODO(b/68493863): Change to use SetOperandLayout(). + const Shape& buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - - // Copy the root instruction's result if its layout does not match the result - // layout constraint. 
- if (constraints.ResultLayout() != nullptr && - !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), - computation->root_instruction())); - computation->set_root_instruction(new_root); - } - return Status::OK(); } @@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation( TF_RETURN_IF_ERROR( ConstrainChannelLayouts(computation, channel_constraints)); } + + // Copy the root instruction's result if its layout does not match the result + // layout constraint. + if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } return Status::OK(); } @@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts( ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); *send_shape = shape; } + } else if (instruction->IsCrossModuleAllReduce()) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel(instruction->all_reduce_id().value(), + instruction->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the channel wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. 
+ HloInstruction* operand = instruction->mutable_operand(0); + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + *instruction->mutable_shape() = shape; + } } } return Status::OK(); |