author    Blake Hechtman <blakehechtman@google.com>  2018-02-03 23:45:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-03 23:48:56 -0800
commit de6037d7505cd24b24fd937d1c011c8d24ca0e81 (patch)
tree   482fde94e1c05c3d723dfcd2476c1c44cbad5269 /tensorflow/compiler/xla/service/layout_assignment.cc
parent a42450a76e43154cc3bf8977c2e9c8afb1d08621 (diff)
[XLA] Assign mandatory constraints in a DFS order and non-mandatory constraints in a BFS order.
PiperOrigin-RevId: 184429818
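
The core of the change is in PropagateConstraints below: each LayoutConstraint now carries a dfs flag, and the worklist is a deque that serves dfs-flagged (mandatory) constraints from the front, depth-first, while non-mandatory constraints wait at the back, breadth-first. A minimal standalone sketch of that ordering, with Constraint as a hypothetical stand-in for XLA's LayoutConstraint class:

#include <deque>
#include <iostream>
#include <string>
#include <vector>

struct Constraint {
  std::string name;
  bool dfs;  // Mandatory constraints request depth-first propagation.
};

int main() {
  std::deque<const Constraint*> worklist;
  std::vector<Constraint> added = {
      {"mandatory_a", /*dfs=*/true},
      {"heuristic_b", /*dfs=*/false},
      {"mandatory_c", /*dfs=*/true},
  };
  for (const Constraint& c : added) {
    // DFS constraints go to the front so they propagate immediately;
    // BFS (non-mandatory) constraints queue at the back.
    if (c.dfs) {
      worklist.push_front(&c);
    } else {
      worklist.push_back(&c);
    }
  }
  while (!worklist.empty()) {
    const Constraint* c = worklist.front();
    worklist.pop_front();
    std::cout << c->name << "\n";  // mandatory_c, mandatory_a, heuristic_b
  }
}
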
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc  70
1 file changed, 39 insertions(+), 31 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 5413b95cfb..fce135ef61 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -61,8 +61,8 @@ std::ostream& operator<<(std::ostream& out,
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
const LogicalBuffer& buffer,
- bool mandatory)
- : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) {
+ bool mandatory, bool dfs)
+ : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
@@ -74,8 +74,8 @@ string BufferLayoutConstraint::ToString() const {
OperandLayoutConstraint::OperandLayoutConstraint(
const ShapeLayout& shape_layout, const HloInstruction* instruction,
- int64 operand_no, bool mandatory)
- : LayoutConstraint(mandatory),
+ int64 operand_no, bool mandatory, bool dfs)
+ : LayoutConstraint(mandatory, dfs),
shape_layout_(shape_layout),
instruction_(instruction),
operand_no_(operand_no) {
@@ -134,7 +134,7 @@ bool LayoutConstraints::OperandBufferForwarded(
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
const LogicalBuffer& buffer,
- bool mandatory) {
+ bool mandatory, bool dfs) {
VLOG(3) << "SetBufferLayout : " << buffer << " : "
<< LayoutUtil::HumanString(layout);
@@ -171,10 +171,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
if (!overwrite) {
iter = buffer_constraints_
.insert(std::make_pair(
- &buffer, BufferLayoutConstraint(layout, buffer, mandatory)))
+ &buffer,
+ BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
.first;
} else {
- iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true);
+ iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
@@ -188,7 +189,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
const HloInstruction* instruction,
- int64 operand_no, bool mandatory) {
+ int64 operand_no, bool mandatory,
+ bool dfs) {
VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
<< operand_no << " : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
@@ -226,12 +228,12 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
if (iter == operand_constraints_.end()) {
auto pair = std::make_pair(
key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
- instruction, operand_no, mandatory));
+ instruction, operand_no, mandatory, dfs));
iter = operand_constraints_.insert(pair).first;
} else {
iter->second =
OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
- operand_no, /*mandatory=*/true);
+ operand_no, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
@@ -240,16 +242,17 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
Status LayoutConstraints::SetArrayOperandLayout(
const Layout& layout, const HloInstruction* instruction, int64 operand_no,
- bool mandatory) {
+ bool mandatory, bool dfs) {
const HloInstruction* operand = instruction->operand(operand_no);
TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
Shape shape(operand->shape());
*shape.mutable_layout() = layout;
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
- return SetOperandLayout(shape, instruction, operand_no, mandatory);
+ return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
}
-Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
+Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
+ bool dfs) {
VLOG(3) << "SetResultLayout : "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
@@ -267,14 +270,15 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
}
result_constraint_.reset(
- new ResultLayoutConstraint(ShapeLayout(shape_with_layout)));
+ new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
added_constraints_.push_back(result_constraint_.get());
return Status::OK();
}
Status LayoutConstraints::SetInstructionLayout(
- const Shape& shape_with_layout, const HloInstruction* instruction) {
+ const Shape& shape_with_layout, const HloInstruction* instruction,
+ bool mandatory, bool dfs) {
VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
<< ShapeUtil::HumanStringWithLayout(shape_with_layout);
@@ -290,8 +294,8 @@ Status LayoutConstraints::SetInstructionLayout(
// instruction.
return ShapeUtil::ForEachSubshapeWithStatus(
shape_with_layout,
- [this, instruction](const Shape& subshape,
- const ShapeIndex& index) -> Status {
+ [this, instruction, mandatory](const Shape& subshape,
+ const ShapeIndex& index) -> Status {
// The precondition for this method is that the instruction defines all
// buffers in its output.
auto buffers =
@@ -300,7 +304,7 @@ Status LayoutConstraints::SetInstructionLayout(
CHECK_EQ(buffers[0]->instruction(), instruction);
if (ShapeUtil::IsArray(subshape)) {
- return SetBufferLayout(subshape.layout(), *buffers[0]);
+ return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
} else {
return Status::OK();
}
@@ -394,8 +398,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
// Constrain the input to the Outfeed instruction to be the expected
// layout of the Outfeed.
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- instruction->outfeed_shape(), instruction, 0,
- /*mandatory=*/true));
+ instruction->outfeed_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
// Parameter layouts must match the respective layout in
// ComputationLayout.
@@ -434,8 +437,8 @@ Status LayoutAssignment::AddMandatoryConstraints(
{0}));
Shape new_shape = channel_constraints->LayoutShapeForChannel(
recv_buffer_shape, instruction->channel_id());
- TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
- new_shape.layout(), *buffer, /*mandatory=*/true));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
}
}
@@ -457,7 +460,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
for (int64 i = 0; i < instruction->operand_count(); ++i) {
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
called_computation_layout.parameter_layout(i).shape(), instruction,
- i, /*mandatory=*/true));
+ i));
}
} else if (instruction->opcode() == HloOpcode::kWhile) {
// Layout of input and output of kWhile instruction must be equal and must
@@ -508,8 +511,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
body_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- body_layout.result_shape(), instruction, 0,
- /*mandatory=*/true));
+ body_layout.result_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
if (!CustomCallRequiresMajorFirstLayout(instruction)) {
continue;
@@ -533,7 +535,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
operand_shape.element_type(),
AsInt64Slice(operand_shape.dimensions()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- row_major_operand_shape, instruction, i, /*mandatory=*/true));
+ row_major_operand_shape, instruction, i));
}
}
}
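
A pattern worth noting in the hunks above: call sites inside AddMandatoryConstraints drop their explicit /*mandatory=*/true arguments. That only compiles if the corresponding declarations in layout_assignment.h (not part of this diff) default the new flags; a guess at what those signatures look like, for illustration only:

// Hypothetical header declarations; the real ones live in
// layout_assignment.h, which this diff does not show.
Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                       bool mandatory = true, bool dfs = true);
Status SetOperandLayout(const Shape& shape_with_layout,
                        const HloInstruction* instruction, int64 operand_no,
                        bool mandatory = true, bool dfs = true);
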
@@ -907,7 +909,11 @@ Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
auto add_new_constraints_to_worklist = [constraints, &worklist]() {
// Add constraints to the front of the deque for DFS ordering.
for (auto* constraint : constraints->ConsumeAddedConstraints()) {
- worklist.push_front(constraint);
+ if (constraint->dfs()) {
+ worklist.push_front(constraint);
+ } else {
+ worklist.push_back(constraint);
+ }
}
};
add_new_constraints_to_worklist();
@@ -1390,7 +1396,7 @@ Status LayoutAssignment::RunOnComputation(
// Add any backend-specific constraints.
TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
- // Propagates layouts from an HLO to its neighbors.
+ // Propagates layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
// While any unconstrained buffers remain, pick an arbitrary buffer, give it a
@@ -1455,7 +1461,12 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
// Assign layouts to computations in an order such that a callee computation
// is handled before its caller computation. This ensures that the layout of
// all callers of a computation will agree.
+ std::list<HloComputation*> computation_post_order =
+ module->MakeComputationPostOrder();
for (auto* computation : module->MakeComputationPostOrder()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
// Clear existing layouts of the instructions. All layouts must be assigned
// by the LayoutAssignment pass, except for those on infeeds, parameters,
// and the computation result. The latter two are specified in
@@ -1467,13 +1478,10 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
-
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(
*entry_computation_layout_, *points_to_analysis,
module->entry_computation(), channel_layout_constraints_));
- } else if (computation->IsFusionComputation()) {
- continue;
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
// Setting all embedded computations to the default layout is potentially