aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc112
1 files changed, 81 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 7067b6f86a..fedc83c8f8 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -175,41 +175,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
- const BufferLayoutConstraint* curr_constraint =
- GetBufferLayoutConstraint(buffer);
- if (curr_constraint != nullptr) {
- if (LayoutUtil::Equal(curr_constraint->layout(), layout)) {
+ auto iter = buffer_constraints_.find(&buffer);
+ if (iter != buffer_constraints_.end()) {
+ const BufferLayoutConstraint& curr_constraint = iter->second;
+ if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
- if (curr_constraint->mandatory()) {
+ if (curr_constraint.mandatory()) {
return FailedPrecondition(
"Buffer %s already has the layout constraint %s, cannot add "
"incompatible constraint %s",
buffer.ToString().c_str(),
- LayoutUtil::HumanString(curr_constraint->layout()).c_str(),
+ LayoutUtil::HumanString(curr_constraint.layout()).c_str(),
LayoutUtil::HumanString(layout).c_str());
}
- }
-
- auto iter = buffer_constraints_.find(&buffer);
- bool overwrite = iter != buffer_constraints_.end();
- if (!overwrite) {
+ iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
+ } else {
+ TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
+ << buffer.ToString();
iter = buffer_constraints_
.insert(std::make_pair(
&buffer,
BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
.first;
- } else {
- iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
-
- // Remove buffer from the set of unconstrained buffers.
- TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) ==
- static_cast<int>(!overwrite));
- unconstrained_buffer_ids_.erase(buffer.id());
-
return Status::OK();
}
@@ -716,7 +707,8 @@ Status CheckParameterLayout(HloInstruction* parameter,
const ComputationLayout& computation_layout) {
const ShapeLayout& parameter_layout =
computation_layout.parameter_layout(parameter->parameter_number());
- if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
+ if (parameter_layout.LayoutIsSet() &&
+ !parameter_layout.MatchesLayoutInShape(parameter->shape())) {
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
@@ -936,14 +928,15 @@ LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
+ saved_entry_computation_layout_(*entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
+ if (channel_layout_constraints_ != nullptr) {
+ // Save a copy of the input ChannelLayoutConstraints so that we can reset it
+ // if we have to undo previous operations (ClearPreviousPassSideEffects()).
+ channel_constraints_ = *channel_layout_constraints_;
+ }
VLOG(1) << "Entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
- // Layouts of all parameter instructions must be set.
- for (const ShapeLayout& parameter_layout :
- entry_computation_layout_->parameter_layouts()) {
- CHECK(parameter_layout.LayoutIsSet());
- }
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1572,6 +1565,13 @@ Status LayoutAssignment::RunOnComputation(
// Propagates layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+ // Prior to applying default layouts, we take note of all HLO instructions
+ // which lack a layout constraint.
+ for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
+ unconstrained_layout_instructions_.insert(
+ points_to_analysis.GetBuffer(buffer_id).instruction());
+ }
+
// While any unconstrained buffers remain, pick an arbitrary buffer, give it a
// layout and propagate the change.
while (!constraints.unconstrained_buffer_ids().empty()) {
@@ -1614,13 +1614,58 @@ Status LayoutAssignment::RunOnComputation(
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
+ if (channel_constraints != nullptr) {
+ TF_RETURN_IF_ERROR(
+ ConstrainChannelLayouts(computation, channel_constraints));
+ }
+ return Status::OK();
+}
+
+Status LayoutAssignment::ConstrainChannelLayouts(
+ HloComputation* computation,
+ ChannelLayoutConstraints* channel_constraints) {
+ // We go through the kRecvDone first. These must either impose their layout,
+ // or find a matching one already existing (ConstrainChannel() returns
+ // nullptr).
for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kRecvDone) {
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(),
+ ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
+ TF_RET_CHECK(layout == nullptr)
+ << instruction->ToString()
+ << " cannot constrain layout as it was set to "
+ << LayoutUtil::HumanString(*layout);
+ }
+ }
+ // After that we go through the kSend. These are likely going to have a kCopy
+ // as operand (otherwise we add it), so in case the constrained layout does
+ // not match, we can change the kCopy layout (and the kSend one as well).
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
- channel_constraints->ConstrainChannel(
- instruction->channel_id(), instruction->operand(0)->shape().layout());
- } else if (instruction->opcode() == HloOpcode::kRecvDone) {
- channel_constraints->ConstrainChannel(instruction->channel_id(),
- instruction->shape().layout());
+ HloInstruction* operand = instruction->mutable_operand(0);
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(), operand->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the kSend wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ Shape* send_shape =
+ ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
+ *send_shape = shape;
+ }
}
}
return Status::OK();
@@ -1672,13 +1717,14 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
// when seen from an outer instruction, which has across-computation
// constraints to impose.
// For example, the kWhile instruction needs to enforce the same layouts for
- // the parameters and root of the bosy, as well as the condition parameters.
+ // the parameters and root of the body, as well as the condition parameters.
// Similarly, the kConditional instruction needs to enforce the same layouts
// for the root of the true and false computations.
// So in the first pass, while allowing the layouts to flow to parameters and
// root, we also fix up the eventually inconsistent ComputationLayout, which
// will be then made mandatory by the second pass.
for (int64 i = 0; i < 2; ++i) {
+ VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -1716,10 +1762,12 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
Status LayoutAssignment::Init() {
computation_layouts_.clear();
+ *entry_computation_layout_ = saved_entry_computation_layout_;
return Status::OK();
}
Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
+ VLOG(5) << "Clearing previous side effects";
// Clear all the copies which have been added, and all the related
// instructions (like GTE and tuples).
int64 removed_copies = 0;
@@ -1737,12 +1785,14 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
}
}
added_copies_.clear();
+ unconstrained_layout_instructions_.clear();
if (removed_copies > 0) {
TupleSimplifier tuple_simplifier;
HloDCE dce;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
+ ResetChannelConstraints();
return Status::OK();
}