author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-02 03:01:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-02 03:05:37 -0700
commit    44da41e4900c3fd481f12c9aa4c49679c9f32fa4 (patch)
tree      8aeb076e3ff0cdfd3d9e5e50661b2963e6b38ac1 /tensorflow/compiler/xla/service/layout_assignment.cc
parent    edea1be5dd98775399dbd12728e86039a14fb967 (diff)
Fix layout assignment for cross module all reduce
Previously we could have ended up with different HLOs being assigned different layouts, which made lowering impossible. This change enforces a consistent layout between the communicating nodes, the same way it is done for send & recv pairs. PiperOrigin-RevId: 215359420
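
The change has two parts. The first and last hunks below teach layout assignment to treat a cross-module all-reduce like a send/recv pair: whichever participant on a channel is constrained first pins the layout, and later participants must adopt it. The middle two hunks move the root-copy fixup from AssignLayouts to RunOnComputation so that it runs after ConstrainChannelLayouts and therefore sees the final, channel-consistent root layout. A minimal standalone sketch of the first-writer-pins-the-layout idea (a hypothetical ChannelLayoutTable; the real pass uses ChannelLayoutConstraints and XLA's Shape/Layout types):

// Minimal sketch: the first participant on a channel pins the layout;
// every later participant either adopts it or learns about the mismatch.
#include <cstdint>
#include <map>
#include <vector>

using Layout = std::vector<int64_t>;  // dimension order, minor to major

class ChannelLayoutTable {
 public:
  bool IsChannelConstrained(int64_t channel_id) const {
    return layouts_.count(channel_id) > 0;
  }
  // Pins `layout` for `channel_id` if the channel is unconstrained and
  // returns nullptr; otherwise returns the previously pinned layout,
  // which the caller must adopt (as the last hunk below does).
  const Layout* ConstrainChannel(int64_t channel_id, const Layout& layout) {
    auto [it, inserted] = layouts_.emplace(channel_id, layout);
    if (inserted || it->second == layout) return nullptr;
    return &it->second;
  }

 private:
  std::map<int64_t, Layout> layouts_;
};

// Usage mirroring the hunks: module A pins the layout, module B must follow.
int main() {
  ChannelLayoutTable table;
  const Layout* clash = table.ConstrainChannel(/*channel_id=*/7, {1, 0});
  // clash == nullptr: channel 7 is now pinned to {1, 0}.
  clash = table.ConstrainChannel(7, {0, 1});
  // clash != nullptr: *clash is {1, 0}; the second all-reduce must be
  // re-laid-out (e.g. through a kCopy) to match it.
  return clash != nullptr ? 0 : 1;
}
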
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc | 65
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffe..25d5327561 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ CHECK(get_channel_constraints(instruction))
+ << "Multi-module layout assignment requires ChannelLayoutConstraints";
+ int64 all_reduce_id = instruction->all_reduce_id().value();
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(all_reduce_id)) {
+ continue;
+ }
+ // TODO(b/68493863): Change to use SetOperandLayout().
+ const Shape& buffer_shape = instruction->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
-
- // Copy the root instruction's result if its layout does not match the result
- // layout constraint.
- if (constraints.ResultLayout() != nullptr &&
- !constraints.ResultLayout()->MatchesLayoutInShape(
- computation->root_instruction()->shape())) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_root,
- CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
- computation->root_instruction()));
- computation->set_root_instruction(new_root);
- }
-
return Status::OK();
}
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
+
+ // Copy the root instruction's result if its layout does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
return Status::OK();
}
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
*send_shape = shape;
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->all_reduce_id().value(),
+ instruction->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the channel wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ HloInstruction* operand = instruction->mutable_operand(0);
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ *instruction->mutable_shape() = shape;
+ }
}
}
return Status::OK();
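
As a companion to the last hunk, a standalone sketch of the marshaling step (a hypothetical Node type, not XLA's HloInstruction): when the channel has already pinned a layout that differs from the all-reduce's, the operand is re-laid-out through a copy, reusing an existing copy when there is one.

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string opcode;             // "copy", "all-reduce", ...
  std::vector<int64_t> layout;    // minor-to-major dimension order
  std::shared_ptr<Node> operand;  // single operand, for simplicity
};

// Re-layout `all_reduce` to `pinned`, mirroring the branch structure above:
// insert a fresh copy unless the operand already is one, then adopt the
// pinned layout on both the (possibly new) copy and the all-reduce itself.
void MarshalToPinnedLayout(Node& all_reduce,
                           const std::vector<int64_t>& pinned) {
  if (all_reduce.layout == pinned) return;  // already consistent
  if (all_reduce.operand->opcode != "copy") {
    auto copy = std::make_shared<Node>(
        Node{"copy", pinned, all_reduce.operand});
    all_reduce.operand = copy;            // new kCopy emits the pinned layout
  } else {
    all_reduce.operand->layout = pinned;  // retarget the existing copy
  }
  all_reduce.layout = pinned;  // the all-reduce now emits the pinned layout
}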