aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc1334
1 files changed, 1334 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
new file mode 100644
index 0000000000..a8f2a6b89c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -0,0 +1,1334 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
+
+#include <algorithm>
+#include <deque>
+#include <functional>
+#include <map>
+#include <memory>
+#include <numeric>
+#include <ostream>
+#include <set>
+#include <string>
+#include <tuple>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+
+// Streams the human-readable form of a constraint so constraints can be
+// logged directly (e.g. LOG(FATAL) << *layout_constraint).
+std::ostream& operator<<(std::ostream& out,
+                         const LayoutConstraint& constraint) {
+  return out << constraint.ToString();
+}
+
+// Binds `layout` to logical buffer `buffer`. The layout must be valid for the
+// buffer's shape; violating this is a programming error (CHECK failure).
+// Note: only a pointer to `buffer` is retained, so the LogicalBuffer must
+// outlive this constraint.
+BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
+                                               const LogicalBuffer& buffer)
+    : layout_(layout), buffer_(&buffer) {
+  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
+}
+
+// Human-readable rendering of this constraint, e.g.
+// "BufferLayoutConstraint <buffer>: <layout>".
+string BufferLayoutConstraint::ToString() const {
+  return tensorflow::strings::StrCat("BufferLayoutConstraint ",
+                                     buffer_->ToString(), ": ",
+                                     LayoutUtil::HumanString(layout_));
+}
+
+// Binds `shape_layout` to operand `operand_no` of `instruction`. The layout
+// must actually be set, and the constrained shape must be compatible with the
+// operand's shape; either violation is a programming error (CHECK failure).
+OperandLayoutConstraint::OperandLayoutConstraint(
+    const ShapeLayout& shape_layout, const HloInstruction* instruction,
+    int64 operand_no)
+    : shape_layout_(shape_layout),
+      instruction_(instruction),
+      operand_no_(operand_no) {
+  CHECK(shape_layout_.LayoutIsSet());
+  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
+                              instruction->operand(operand_no)->shape()));
+}
+
+// Human-readable rendering of this constraint, e.g.
+// "OperandLayoutConstraint <instr>, operand <n>: <shape layout>".
+string OperandLayoutConstraint::ToString() const {
+  return tensorflow::strings::StrCat("OperandLayoutConstraint ",
+                                     instruction_->name(), ", operand ",
+                                     operand_no_, ": ",
+                                     shape_layout_.ToString());
+}
+
+// Human-readable rendering of this constraint, e.g.
+// "ResultLayoutConstraint: <shape layout>".
+string ResultLayoutConstraint::ToString() const {
+  return tensorflow::strings::StrCat("ResultLayoutConstraint: ",
+                                     shape_layout_.ToString());
+}
+
+// Builds the constraint set for `computation`. Initially no buffer has a
+// layout constraint; every array-shaped logical buffer starts out in the
+// unconstrained set and is removed as SetBufferLayout is called for it.
+// `points_to_analysis` must outlive this object (held by reference).
+LayoutConstraints::LayoutConstraints(
+    const TuplePointsToAnalysis& points_to_analysis,
+    const HloComputation* computation)
+    : points_to_analysis_(points_to_analysis), computation_(computation) {
+  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
+  for (auto& buffer : points_to_analysis_.logical_buffers()) {
+    // Only array buffers can carry a layout; tuples/opaque buffers are skipped.
+    if (buffer->IsArray()) {
+      unconstrained_buffer_ids_.insert(buffer->id());
+    }
+  }
+}
+
+// Returns true if any logical buffer of operand `operand_no` may also appear
+// in the output of `instruction`, i.e. the instruction potentially forwards
+// (aliases) the operand's storage rather than defining fresh buffers.
+bool LayoutConstraints::OperandBufferForwarded(
+    const HloInstruction* instruction, int64 operand_no) const {
+  // The operand is potentially forwarded if the intersection of points-to sets
+  // of the operand and the instruction is non-empty.
+  auto output_buffers =
+      points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet();
+  auto operand_buffers =
+      points_to_analysis_.GetPointsToSet(instruction->operand(operand_no))
+          .CreateFlattenedSet();
+  // NOTE(review): set_intersection requires both flattened sets to be sorted
+  // ranges; CreateFlattenedSet appears to return an ordered set — confirm.
+  std::vector<const LogicalBuffer*> intersection;
+  std::set_intersection(output_buffers.begin(), output_buffers.end(),
+                        operand_buffers.begin(), operand_buffers.end(),
+                        std::back_inserter(intersection));
+  return !intersection.empty();
+}
+
+// Constrains `buffer` to have `layout`. Fails if the buffer is not
+// array-shaped, the layout is invalid for the buffer's shape, or a different
+// layout constraint already exists for the buffer. Adding an identical
+// constraint twice is a no-op. On success the buffer leaves the
+// unconstrained set and the new constraint is queued in added_constraints_
+// for later propagation.
+Status LayoutConstraints::SetBufferLayout(const Layout& layout,
+                                          const LogicalBuffer& buffer) {
+  VLOG(3) << "SetBufferLayout : " << buffer << " : "
+          << LayoutUtil::HumanString(layout);
+
+  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
+  if (!buffer.IsArray()) {
+    return FailedPrecondition(
+        "Layout of buffer %s cannot be constrained because buffer is not "
+        "array-shaped, has shape: %s",
+        buffer.ToString().c_str(),
+        ShapeUtil::HumanString(buffer.shape()).c_str());
+  }
+  TF_RETURN_IF_ERROR(
+      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
+
+  const Layout* curr_layout = BufferLayout(buffer);
+  if (curr_layout != nullptr) {
+    if (!LayoutUtil::Equal(*curr_layout, layout)) {
+      return FailedPrecondition(
+          "Buffer %s already has the layout constraint %s, cannot add "
+          "incompatible constraint %s",
+          buffer.ToString().c_str(),
+          LayoutUtil::HumanString(*curr_layout).c_str(),
+          LayoutUtil::HumanString(layout).c_str());
+    }
+    // New constraint matches existing constraint. Nothing to do.
+    return Status::OK();
+  }
+
+  // Record the constraint; the map owns the BufferLayoutConstraint, and
+  // added_constraints_ keeps a pointer into it for the propagation worklist.
+  auto new_constraint_it = buffer_constraints_.insert(
+      {&buffer, BufferLayoutConstraint(layout, buffer)});
+  added_constraints_.push_back(&new_constraint_it.first->second);
+
+  // Remove buffer from the set of unconstrained buffers.
+  TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == 1);
+  unconstrained_buffer_ids_.erase(buffer.id());
+
+  return Status::OK();
+}
+
+// Constrains operand `operand_no` of `instruction` to the layout in
+// `shape_with_layout`. Fails if an incompatible constraint already exists for
+// this (instruction, operand) pair, or if the operand's buffers may be
+// forwarded into the instruction's output (such a constraint would affect
+// layouts beyond this immediate use). A matching duplicate is a no-op.
+Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
+                                           const HloInstruction* instruction,
+                                           int64 operand_no) {
+  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
+          << operand_no << " : "
+          << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+  const ShapeLayout* curr_shape_layout = OperandLayout(instruction, operand_no);
+  if (curr_shape_layout != nullptr) {
+    if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
+      return FailedPrecondition(
+          "Operand %lld of instruction %s already has a layout constraint "
+          "%s, cannot add incompatible constraint %s",
+          operand_no, instruction->name().c_str(),
+          curr_shape_layout->ToString().c_str(),
+          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+    }
+    // New constraint matches existing constraint. Nothing to do.
+    return Status::OK();
+  }
+
+  // If any buffers in the operand occur in the output of the instruction, then
+  // return an error. This case is not handled because such a constraint changes
+  // layouts beyond this immediate use and is complicated to handle.
+  if (OperandBufferForwarded(instruction, operand_no)) {
+    // Fixed grammar in the error message ("constraint" -> "constrain").
+    return FailedPrecondition(
+        "Cannot constrain layout of operand %lld of instruction %s "
+        "because instruction forwards operand's LogicalBuffer(s)",
+        operand_no, instruction->name().c_str());
+  }
+
+  // Record the constraint; added_constraints_ holds a pointer into the map
+  // entry for the propagation worklist.
+  auto key = std::make_pair(instruction, operand_no);
+  auto new_constraint_it = operand_constraints_.insert(
+      {key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
+                                    operand_no)});
+  added_constraints_.push_back(&new_constraint_it.first->second);
+
+  return Status::OK();
+}
+
+// Convenience wrapper over SetOperandLayout for array-shaped operands: stamps
+// `layout` onto a copy of the operand's shape and constrains the operand to
+// that shape-with-layout. The operand must be an array.
+Status LayoutConstraints::SetArrayOperandLayout(
+    const Layout& layout, const HloInstruction* instruction, int64 operand_no) {
+  const HloInstruction* operand = instruction->operand(operand_no);
+  TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
+  Shape shape_with_layout(operand->shape());
+  *shape_with_layout.mutable_layout() = layout;
+  TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape_with_layout));
+  return SetOperandLayout(shape_with_layout, instruction, operand_no);
+}
+
+// Constrains the layout of this computation's result to `shape_with_layout`.
+// Fails if a different result constraint has already been added; a matching
+// duplicate is a no-op.
+Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
+  VLOG(3) << "SetResultLayout : "
+          << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+  const ShapeLayout* curr_shape_layout = ResultLayout();
+  if (curr_shape_layout != nullptr) {
+    if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
+      return FailedPrecondition(
+          "Result of computation %s already has the layout constraint %s, "
+          "cannot add incompatible constraint %s",
+          computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
+          ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+    }
+    // New constraint matches existing constraint. Nothing to do.
+    return Status::OK();
+  }
+
+  // MakeUnique instead of reset(new ...) for consistency with the rest of
+  // this file.
+  result_constraint_ =
+      MakeUnique<ResultLayoutConstraint>(ShapeLayout(shape_with_layout));
+  added_constraints_.push_back(result_constraint_.get());
+
+  return Status::OK();
+}
+
+// Constrains the layout of every array in `instruction`'s output to the
+// corresponding layout in `shape_with_layout`. Precondition: the instruction
+// defines all buffers in its output (each output element points to exactly
+// one buffer, produced by this instruction); this is CHECKed below.
+Status LayoutConstraints::SetInstructionLayout(
+    const Shape& shape_with_layout, const HloInstruction* instruction) {
+  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
+          << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
+    return FailedPrecondition(
+        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
+        instruction->name().c_str(),
+        ShapeUtil::HumanString(instruction->shape()).c_str(),
+        ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+  }
+
+  // Create a BufferLayoutConstraint for each array shape in the output of the
+  // instruction.
+  return ShapeUtil::ForEachSubshape(
+      shape_with_layout,
+      [this, instruction](const Shape& subshape,
+                          const ShapeIndex& index) -> Status {
+        // The precondition for this method is that the instruction defines all
+        // buffers in its output.
+        auto buffers =
+            points_to_analysis_.GetPointsToSet(instruction).element(index);
+        CHECK_EQ(1, buffers.size());
+        CHECK_EQ(buffers[0]->instruction(), instruction);
+
+        // Tuple subshapes carry no layout; only arrays are constrained.
+        if (ShapeUtil::IsArray(subshape)) {
+          return SetBufferLayout(subshape.layout(), *buffers[0]);
+        } else {
+          return Status::OK();
+        }
+      });
+}
+
+// Returns the layout constrained for `buffer`, or nullptr if the buffer has
+// no constraint yet. The pointer aliases the stored constraint.
+const Layout* LayoutConstraints::BufferLayout(
+    const LogicalBuffer& buffer) const {
+  auto it = buffer_constraints_.find(&buffer);
+  if (it == buffer_constraints_.end()) {
+    return nullptr;
+  }
+  return &it->second.layout();
+}
+
+// Returns the shape layout constrained for operand `operand_no` of
+// `instruction`, or nullptr if that operand has no constraint yet.
+const ShapeLayout* LayoutConstraints::OperandLayout(
+    const HloInstruction* instruction, int64 operand_no) const {
+  auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
+  if (it == operand_constraints_.end()) {
+    return nullptr;
+  }
+  return &it->second.shape_layout();
+}
+
+// Returns the constrained layout of the computation result, or nullptr if no
+// result constraint has been set.
+const ShapeLayout* LayoutConstraints::ResultLayout() const {
+  if (result_constraint_ == nullptr) {
+    return nullptr;
+  }
+  return &result_constraint_->shape_layout();
+}
+
+// Returns a multi-line human-readable dump of all constraints in this set:
+// per instruction (in post-order), its operand constraints and the buffer
+// constraints for buffers it defines, followed by the result constraint.
+// Intended for debugging/logging only.
+string LayoutConstraints::ToString() const {
+  string output;
+  tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
+                                 computation_->name(), ":\n");
+  for (auto* instruction : computation_->MakeInstructionPostOrder()) {
+    tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(),
+                                   "\n");
+    for (int64 i = 0; i < instruction->operand_count(); ++i) {
+      if (OperandLayout(instruction, i) != nullptr) {
+        tensorflow::strings::StrAppend(
+            &output, "  operand (", i, "): ",
+            OperandLayout(instruction, i)->ToString(), "\n");
+      }
+    }
+    for (const LogicalBuffer* buffer :
+         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
+      if (BufferLayout(*buffer) != nullptr) {
+        tensorflow::strings::StrAppend(
+            &output, "  ", buffer->ToString(), " : ",
+            LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
+      }
+    }
+  }
+
+  if (ResultLayout() != nullptr) {
+    tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(),
+                                   "\n");
+  }
+  return output;
+}
+
+// Seeds `constraints` with the layout constraints that are not negotiable:
+// (1) instructions whose values carry pre-existing layouts (constants,
+// infeeds, parameters); (2) instructions calling already-laid-out
+// computations (kCall, kWhile, kCustomCall); (3) the computation's result
+// layout from `computation_layout`. Called computations must already be
+// present in computation_layouts_ (callees are assigned before callers).
+Status LayoutAssignment::AddMandatoryConstraints(
+    const ComputationLayout& computation_layout, HloComputation* computation,
+    LayoutConstraints* constraints) {
+  VLOG(3) << "Adding mandatory layout constraints to computation "
+          << computation->name();
+
+  // Constrain layouts of instructions which define values with pre-existing
+  // layouts.
+  for (auto& instruction : computation->instructions()) {
+    Shape const* shape_with_layout = nullptr;
+    if (instruction->opcode() == HloOpcode::kConstant) {
+      // Constant layouts must match the layout of their literal.
+      shape_with_layout = &instruction->literal().shape();
+    } else if (instruction->opcode() == HloOpcode::kInfeed) {
+      // Infeed layouts must match the layout of the original inserted
+      // instruction.
+      // TODO(b/31425034): Change infeeds to be more like parameters, with
+      // shapes in the ComputationLayout.
+      shape_with_layout = &instruction->shape();
+    } else if (instruction->opcode() == HloOpcode::kParameter) {
+      // Parameter layouts must match the respective layout in
+      // ComputationLayout.
+      shape_with_layout =
+          &computation_layout.parameter_layout(instruction->parameter_number())
+               .shape();
+    }
+    if (shape_with_layout != nullptr) {
+      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(*shape_with_layout,
+                                                           instruction.get()));
+    }
+  }
+
+  // Constrain layouts of instructions which call computations which have
+  // already been assigned layouts. Instructions which call computations in a
+  // parallel element-wise context (eg, map or reduce) do not need layout
+  // constraints because they operate on scalars.
+  for (auto& instruction : computation->instructions()) {
+    if (instruction->opcode() == HloOpcode::kCall) {
+      // kCall instruction operands and output must match the ComputationLayout
+      // of the called computation.
+      const ComputationLayout& called_computation_layout =
+          FindOrDie(computation_layouts_, instruction->to_apply());
+      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+          called_computation_layout.result_layout().shape(),
+          instruction.get()));
+      TF_RET_CHECK(instruction->operand_count() ==
+                   called_computation_layout.parameter_count());
+      for (int64 i = 0; i < instruction->operand_count(); ++i) {
+        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+            called_computation_layout.parameter_layout(i).shape(),
+            instruction.get(), i));
+      }
+    } else if (instruction->opcode() == HloOpcode::kWhile) {
+      // Layout of input and output of kWhile instruction must be equal and must
+      // match both input and output of body computation. Also, the input of
+      // condition computation must match kWhile layout.
+      HloComputation* body = instruction->while_body();
+      HloComputation* condition = instruction->while_condition();
+      const HloInstruction* init = instruction->operand(0);
+      const ComputationLayout& body_layout =
+          FindOrDie(computation_layouts_, body);
+      const ComputationLayout& condition_layout =
+          FindOrDie(computation_layouts_, condition);
+
+      // Check a few invariants irrespective of layout.
+      CHECK_EQ(1, instruction->operand_count());
+      CHECK_EQ(1, body->num_parameters());
+      CHECK_EQ(1, condition->num_parameters());
+      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
+                                   body_layout.parameter_shape(0)));
+      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
+                                   condition_layout.parameter_shape(0)));
+      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
+
+      // Return error if earlier layout assignment of the embedded computations
+      // has produced conflicting layouts.
+      if (!ShapeUtil::Equal(body_layout.result_shape(),
+                            body_layout.parameter_shape(0))) {
+        return InternalError(
+            "Parameter and result of body computation %s of while instruction "
+            "%s have different layouts: %s vs %s",
+            body->name().c_str(), instruction->name().c_str(),
+            ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
+            ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
+      }
+      if (!ShapeUtil::Equal(body->root_instruction()->shape(),
+                            condition->parameter_instruction(0)->shape())) {
+        return InternalError(
+            "Parameter of condition computation %s of while instruction "
+            "%s does not match body computation %s result: %s vs %s",
+            condition->name().c_str(), instruction->name().c_str(),
+            body->name().c_str(),
+            ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
+            ShapeUtil::HumanString(body_layout.result_shape()).c_str());
+      }
+
+      // Constrain the output and the operand of the while instruction to match
+      // the computations.
+      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+          body_layout.result_shape(), instruction.get()));
+      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+          body_layout.result_shape(), instruction.get(), 0));
+    } else if (instruction->opcode() == HloOpcode::kCustomCall) {
+      // Add constraints for kCustomCall instruction operands and instructions.
+      // For now we only support row major layouts for all inputs and outputs.
+      auto row_major_shape = [](const Shape& old_shape) {
+        Shape new_shape(old_shape);
+        std::vector<int64> dimension_order(new_shape.dimensions_size());
+        // Descending minor-to-major list, i.e. dimension 0 is most major:
+        // row-major layout.
+        std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+        *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+        return new_shape;
+      };
+
+      Shape result_shape(row_major_shape(instruction->shape()));
+      TF_RETURN_IF_ERROR(
+          constraints->SetInstructionLayout(result_shape, instruction.get()));
+      for (int64 i = 0; i < instruction->operand_count(); ++i) {
+        const Shape& operand_shape = instruction->operand(i)->shape();
+        // Opaque operands don't get a layout constraint.
+        if (ShapeUtil::IsOpaque(operand_shape)) {
+          continue;
+        }
+
+        Shape row_major_operand_shape(row_major_shape(operand_shape));
+        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+            row_major_operand_shape, instruction.get(), i));
+      }
+    }
+  }
+
+  // Finally set the result layout to match ComputationLayout.
+  return constraints->SetResultLayout(
+      computation_layout.result_layout().shape());
+}
+
+namespace {
+
+// The operands of a call must match the layouts of parameters in the
+// ComputationLayout, and the call instruction itself must match the result
+// layout in the ComputationLayout.
+// `computation_layout` is the layout assigned to call->to_apply(); returns an
+// internal error (via TF_RET_CHECK) on any mismatch.
+Status CheckCallLayout(HloInstruction* call,
+                       const ComputationLayout& computation_layout) {
+  HloComputation* computation = call->to_apply();
+  TF_RET_CHECK(computation->num_parameters() == call->operand_count());
+  for (int64 i = 0; i < computation->num_parameters(); ++i) {
+    TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
+        call->operand(i)->shape()));
+  }
+  TF_RET_CHECK(
+      computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
+  return Status::OK();
+}
+
+// Custom calls have fixed input and output layouts.
+// Specifically, every operand and the result must be row major (monotonic
+// layout with dimension 0 major), matching the constraints added in
+// AddMandatoryConstraints.
+Status CheckCustomCallLayout(HloInstruction* custom_call) {
+  for (const HloInstruction* operand : custom_call->operands()) {
+    TF_RET_CHECK(
+        LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+  }
+  TF_RET_CHECK(
+      LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
+  return Status::OK();
+}
+
+// For a while instruction, all the following layouts must be the same:
+// (1) init operand
+// (2) condition computation parameter
+// (3) body computation parameter
+// (4) body computation result
+// (5) while instruction result
+// Returns an internal error (via TF_RET_CHECK) on any mismatch.
+Status CheckWhileLayout(HloInstruction* while_inst,
+                        const ComputationLayout& condition_computation_layout,
+                        const ComputationLayout& body_computation_layout) {
+  // Bind by const reference: `auto` here would copy the Shape.
+  const Shape& init_shape = while_inst->operand(0)->shape();
+  TF_RET_CHECK(
+      condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+          init_shape));
+  TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+      init_shape));
+  TF_RET_CHECK(
+      body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
+  TF_RET_CHECK(
+      LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
+  return Status::OK();
+}
+
+// Fusion parameters must match the layout of the fusion instructions operands,
+// and the root of the fusion expression must match the layout of the fusion
+// instruction.
+// Returns an internal error (via TF_RET_CHECK) on any mismatch.
+Status CheckFusionLayout(HloInstruction* fusion) {
+  TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
+
+  TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+      fusion->shape(), fusion->fused_expression_root()->shape()));
+  for (int64 i = 0; i < fusion->operand_count(); ++i) {
+    TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+        fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
+  }
+  return Status::OK();
+}
+
+// The layout of a parameter must match the respective layout in the
+// computation's ComputationLayout.
+// `computation_layout` is the layout of the computation containing
+// `parameter`.
+Status CheckParameterLayout(HloInstruction* parameter,
+                            const ComputationLayout& computation_layout) {
+  const ShapeLayout& parameter_layout =
+      computation_layout.parameter_layout(parameter->parameter_number());
+  if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
+    return InternalError(
+        "parameter instruction %s does not match layout of computation "
+        "shape: %s",
+        parameter->ToString().c_str(), parameter_layout.ToString().c_str());
+  }
+  return Status::OK();
+}
+
+// The layout of a constant instruction must match the layout of its literal,
+// since the literal's bytes are what actually get materialized.
+Status CheckConstantLayout(HloInstruction* constant) {
+  if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
+                                        constant->shape())) {
+    return InternalError(
+        "constant instruction %s does not match the layout of its literal %s",
+        constant->ToString().c_str(),
+        ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
+  }
+  return Status::OK();
+}
+
+// Check that all layouts in the module have been set and satisfy all necessary
+// conditions.
+// Runs a fresh points-to analysis over `module` and verifies, per
+// instruction: (1) a layout exists and the shape is valid, (2) each output
+// subshape agrees with every logical buffer that may flow into it, and
+// (3) opcode-specific invariants (call/custom-call/fusion/parameter/
+// constant/while). Finally checks the entry root against the entry
+// computation's result layout.
+Status CheckLayouts(
+    HloModule* module,
+    const std::map<HloComputation*, ComputationLayout>& computation_layouts) {
+  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+                      TuplePointsToAnalysis::Run(module));
+  for (auto& computation : module->computations()) {
+    for (auto& instruction : computation->instructions()) {
+      // Verify every instruction has a layout and the layout is valid for the
+      // shape.
+      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
+      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
+
+      // Use points-to analysis to verify that every subshape element in the
+      // output of the instruction matches the layout of the logical buffer
+      // which could be the source of the subshape value.
+      const PointsToSet& points_to_set =
+          points_to_analysis->GetPointsToSet(instruction.get());
+      TF_RETURN_IF_ERROR(points_to_set.ForEachElement(
+          [&instruction](
+              ShapeIndex index, bool is_leaf,
+              const std::vector<const LogicalBuffer*>& buffers) -> Status {
+            if (is_leaf) {
+              const Shape& instruction_subshape =
+                  ShapeUtil::GetSubshape(instruction->shape(), index);
+              for (const LogicalBuffer* buffer : buffers) {
+                if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
+                  return InternalError(
+                      "Layout of instruction %s at index {%s} does not match "
+                      "source LogicalBuffer %s: %s vs %s",
+                      instruction->name().c_str(),
+                      tensorflow::str_util::Join(index, ",").c_str(),
+                      buffer->ToString().c_str(),
+                      ShapeUtil::HumanStringWithLayout(instruction_subshape)
+                          .c_str(),
+                      ShapeUtil::HumanStringWithLayout(buffer->shape())
+                          .c_str());
+                }
+              }
+            }
+            return Status::OK();
+          }));
+
+      // Verify instructions that have special layout constraints.
+      switch (instruction->opcode()) {
+        case HloOpcode::kCall:
+          TF_RETURN_IF_ERROR(CheckCallLayout(
+              instruction.get(),
+              FindOrDie(computation_layouts, instruction->to_apply())));
+          break;
+        case HloOpcode::kCustomCall:
+          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction.get()));
+          break;
+        case HloOpcode::kFusion:
+          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction.get()));
+          break;
+        case HloOpcode::kParameter:
+          TF_RETURN_IF_ERROR(CheckParameterLayout(
+              instruction.get(),
+              FindOrDie(computation_layouts, instruction->parent())));
+          break;
+        case HloOpcode::kConstant:
+          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction.get()));
+          break;
+        case HloOpcode::kWhile:
+          TF_RETURN_IF_ERROR(CheckWhileLayout(
+              instruction.get(),
+              FindOrDie(computation_layouts, instruction->while_condition()),
+              FindOrDie(computation_layouts, instruction->while_body())));
+          break;
+        default:
+          break;
+      }
+    }
+  }
+
+  // Finally verify the result layout matches the layout of the entry
+  // computation root.
+  TF_RET_CHECK(ShapeUtil::Equal(
+      module->entry_computation()->root_instruction()->shape(),
+      FindOrDie(computation_layouts, module->entry_computation())
+          .result_layout()
+          .shape()));
+
+  return Status::OK();
+}
+
+} // namespace
+
+// Constructs the pass. `entry_computation_layout` is mutable state owned by
+// the caller and must outlive the pass: all parameter layouts must already be
+// set, and an unset result layout is defaulted here.
+LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout)
+    : HloPass("layout-assignment"),
+      entry_computation_layout_(entry_computation_layout) {
+  VLOG(1) << "entry computation layout given to layout assignment: "
+          << entry_computation_layout_->ToString();
+  // Layouts of all parameter instructions must be set.
+  for (const ShapeLayout& parameter_layout :
+       entry_computation_layout_->parameter_layouts()) {
+    CHECK(parameter_layout.LayoutIsSet());
+  }
+  // If the result layout is not set, then choose the default.
+  // TODO(b/29118294): Choose a better layout in this case.
+  if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
+    entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
+  }
+}
+
+namespace {
+
+// Given a permutation of `{0, 1, ..., n}` `indices`, returns a permutation of
+// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the
+// indices `to_delete` wherever in `indices` they are, and inserting the indices
+// `to_insert` arbitrarily at the back.
+// Used to map a minor-to-major list across a reshape that only removes/adds
+// degenerate (size-1) dimensions. `to_delete`/`to_insert` are taken by value
+// because they are sorted locally.
+tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>
+DeleteAndInsertIndices(
+    std::vector<int64> to_delete, std::vector<int64> to_insert,
+    tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64> indices) {
+  // Delete in descending order and insert in ascending order so that earlier
+  // deletions/insertions do not shift the positions processed later.
+  std::sort(to_delete.begin(), to_delete.end(), std::greater<int64>());
+  std::sort(to_insert.begin(), to_insert.end(), std::less<int64>());
+  for (auto index : to_delete) {
+    auto i = indices.begin();
+    while (i != indices.end()) {
+      if (*i == index) {
+        i = indices.erase(i);
+      } else {
+        // Renumber the surviving indices above the deleted one.
+        if (*i > index) {
+          (*i)--;
+        }
+        ++i;
+      }
+    }
+  }
+  for (auto index : to_insert) {
+    // Shift existing indices up to make room, then append the new index.
+    for (auto i = indices.begin(); i != indices.end(); ++i) {
+      if (*i >= index) {
+        (*i)++;
+      }
+    }
+    indices.Add(index);
+  }
+  return indices;
+}
+
+} // namespace
+
+// Given that `instruction` produces `output_layout`, proposes a layout for
+// operand `operand_no` that enables buffer reuse or a bitcast (elementwise
+// ops, degenerate-dim-only reshapes, transposes). Returns nullptr when no
+// preferred operand layout can be derived. Both the instruction and the
+// operand must be array-shaped.
+std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
+    const Layout& output_layout, const HloInstruction* instruction,
+    int64 operand_no) {
+  const HloInstruction* operand = instruction->operand(operand_no);
+
+  CHECK(ShapeUtil::IsArray(instruction->shape()) &&
+        ShapeUtil::IsArray(operand->shape()));
+
+  if (instruction->IsElementwiseOnOperand(operand_no) &&
+      !ShapeUtil::IsScalar(operand->shape()) &&
+      ShapeUtil::Rank(operand->shape()) ==
+          ShapeUtil::Rank(instruction->shape())) {
+    // Assign operands the same layout as the instruction, so that
+    // 1) the elementwise operation can reuse its operand's buffer, and
+    // 2) the input and output elements can reuse the same linear index.
+    //
+    // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
+    // from assigning the same layout to input and output.
+    return MakeUnique<Layout>(output_layout);
+  }
+
+  if (instruction->opcode() == HloOpcode::kReshape) {
+    // Pick the operand layout that makes the reshape a bitcast. If the reshape
+    // only inserts or deletes degenerate dimensions, we can easily compute the
+    // desired layout by accordingly inserting and deleting the elements in the
+    // minor-to-major list.
+    bool merely_inserts_or_deletes_1_sized_dims;
+    std::vector<int64> inserted_indices, deleted_indices;
+    std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
+             inserted_indices) =
+        instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+    if (merely_inserts_or_deletes_1_sized_dims) {
+      // Going output->operand, dims inserted by the reshape are deleted from
+      // the minor-to-major list and dims it deleted are re-inserted.
+      Layout operand_layout = LayoutUtil::MakeLayout(
+          AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices,
+                                              output_layout.minor_to_major())));
+      TF_CHECK_OK(
+          LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
+      return MakeUnique<Layout>(operand_layout);
+    }
+  }
+
+  if (instruction->opcode() == HloOpcode::kTranspose) {
+    // Pick the operand layout that makes the transpose a bitcast.
+    std::vector<int64> perm =
+        ComposePermutations(instruction->dimensions(),
+                            AsInt64Slice(output_layout.minor_to_major()));
+    Layout operand_layout = LayoutUtil::MakeLayout(perm);
+    TF_CHECK_OK(
+        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
+    return MakeUnique<Layout>(operand_layout);
+  }
+
+  // No preference for this instruction/operand pair.
+  return nullptr;
+}
+
+// Inverse direction of ChooseOperandLayoutFromOutputLayout: given that
+// operand `operand_no` of `user` has `operand_layout`, proposes a layout for
+// `user`'s output that enables buffer reuse or a bitcast. Returns nullptr
+// when no preferred output layout can be derived. Both the user and the
+// operand must be array-shaped.
+std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
+    const Layout& operand_layout, const HloInstruction* user,
+    int64 operand_no) {
+  const HloInstruction* operand = user->operand(operand_no);
+
+  CHECK(ShapeUtil::IsArray(user->shape()) &&
+        ShapeUtil::IsArray(operand->shape()));
+
+  if (user->IsElementwiseOnOperand(operand_no) &&
+      !ShapeUtil::IsScalar(operand->shape()) &&
+      ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
+    // Assign users the same layout as the operand.
+    return MakeUnique<Layout>(operand_layout);
+  }
+
+  if (user->opcode() == HloOpcode::kReshape) {
+    // Pick the user layout that makes the reshape a bitcast.
+    bool merely_inserts_or_deletes_1_sized_dims;
+    std::vector<int64> inserted_indices, deleted_indices;
+    std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
+             inserted_indices) =
+        user->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+    if (merely_inserts_or_deletes_1_sized_dims) {
+      // Going operand->output, dims the reshape deletes are removed from the
+      // minor-to-major list and dims it inserts are added (mirror image of
+      // the operand-from-output direction).
+      Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice(
+          DeleteAndInsertIndices(deleted_indices, inserted_indices,
+                                 operand_layout.minor_to_major())));
+      TF_CHECK_OK(
+          LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+      return MakeUnique<Layout>(user_layout);
+    }
+  }
+
+  if (user->opcode() == HloOpcode::kTranspose) {
+    // Pick the user layout that makes the transpose a bitcast.
+    // To become a bitcast, the layouts need to satisfy
+    //   collapsing_order * output_layout = input_layout
+    // so output_layout = inverse(collapsing_order) * input_layout
+    std::vector<int64> perm =
+        Permute(InversePermutation(user->dimensions()),
+                AsInt64Slice(operand_layout.minor_to_major()));
+    Layout user_layout = LayoutUtil::MakeLayout(perm);
+    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+    return MakeUnique<Layout>(user_layout);
+  }
+
+  // No preference for this user/operand pair.
+  return nullptr;
+}
+
+// Fixed-point propagation: drains constraints->ConsumeAddedConstraints()
+// into a worklist and dispatches each constraint to the matching
+// Propagate{Buffer,Operand,Result}Constraint hook, which may in turn add new
+// constraints that are picked up on the next loop iteration. Terminates when
+// no propagation step adds a new constraint.
+Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
+  // Gathers all initial constraints in a worklist and propagates them in
+  // depth-first order. DFS order seems to be better than BFS because a
+  // constraint is propagated as far as possible before propagating unrelated
+  // constraints which makes it less likely that conflicting constraints will be
+  // propagated to instructions. However, we should experiment with other orders
+  // too.
+  std::deque<const LayoutConstraint*> worklist;
+
+  // Lambda for moving newly added constraints to the worklist.
+  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
+    // Add constraints to the front of the deque for DFS ordering.
+    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
+      worklist.push_front(constraint);
+    }
+  };
+  add_new_constraints_to_worklist();
+
+  while (!worklist.empty()) {
+    const LayoutConstraint* layout_constraint = worklist.front();
+    worklist.pop_front();
+    VLOG(2) << "Propagating " << layout_constraint->ToString()
+            << " to its neighbors.";
+    // Dispatch on the concrete constraint type via dynamic_cast; an unknown
+    // subtype is a programming error (LOG(FATAL) below).
+    if (auto* buffer_constraint =
+            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateBufferConstraint(*buffer_constraint, constraints));
+    } else if (auto* operand_constraint =
+                   dynamic_cast<const OperandLayoutConstraint*>(
+                       layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateOperandConstraint(*operand_constraint, constraints));
+    } else if (auto* result_constraint =
+                   dynamic_cast<const ResultLayoutConstraint*>(
+                       layout_constraint)) {
+      TF_RETURN_IF_ERROR(
+          PropagateResultConstraint(*result_constraint, constraints));
+    } else {
+      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
+    }
+
+    add_new_constraints_to_worklist();
+  }
+  return Status::OK();
+}
+
+namespace {
+
+// Collects every array-shaped use of 'buffer' (or of any of its aliases) as
+// (user instruction, operand number) pairs.
+std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
+    const LogicalBuffer& buffer,
+    const TuplePointsToAnalysis& points_to_analysis) {
+  CHECK(buffer.IsArray());
+  std::vector<std::pair<const HloInstruction*, int64>> result;
+  for (const auto& alias : points_to_analysis.GetBufferAliases(buffer)) {
+    const HloInstruction* aliased_instruction = alias.instruction();
+    if (!ShapeUtil::IsArray(aliased_instruction->shape())) {
+      continue;
+    }
+    // Because the aliased instruction produces an array, the alias must sit
+    // at the top level (index == {}) of its result.
+    CHECK(alias.index().empty());
+
+    // Record every (user, operand number) pair consuming this output.
+    for (const HloInstruction* user : aliased_instruction->users()) {
+      for (int64 operand_no : user->OperandIndices(aliased_instruction)) {
+        result.emplace_back(user, operand_no);
+      }
+    }
+  }
+  return result;
+}
+
+} // namespace
+
+Status LayoutAssignment::PropagateUseConstraintToDefs(
+    const ShapeLayout& shape_layout, const HloInstruction* instruction,
+    LayoutConstraints* constraints) {
+  // Walk the points-to set of 'instruction' and constrain each logical buffer
+  // that may flow into the constrained use to the corresponding layout.
+  const PointsToSet& points_to_set =
+      constraints->points_to_analysis().GetPointsToSet(instruction);
+  return points_to_set.ForEachElement(
+      [this, &shape_layout, constraints](
+          const ShapeIndex& index, bool is_leaf,
+          const std::vector<const LogicalBuffer*>& buffers) -> Status {
+        if (!is_leaf) {
+          return Status::OK();
+        }
+        for (const LogicalBuffer* buffer : buffers) {
+          // Skip buffers that already carry a constraint or are not arrays.
+          if (constraints->BufferLayout(*buffer) == nullptr &&
+              ShapeUtil::IsArray(buffer->shape())) {
+            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
+                ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
+                *buffer));
+          }
+        }
+        return Status::OK();
+      });
+}
+
+Status LayoutAssignment::PropagateOperandConstraint(
+    const OperandLayoutConstraint& operand_constraint,
+    LayoutConstraints* constraints) {
+  // Try to set the layout of the logical buffers in the given operand to match
+  // the constrained layout. This avoids copies.
+  TF_RETURN_IF_ERROR(
+      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
+                                   operand_constraint.operand(), constraints));
+
+  // For array-shaped operands and user instructions try to pick a minimum cost
+  // layout. For example, if the operand of an elementwise instruction is
+  // constrained to a certain layout we want the output of the instruction to
+  // have the same layout.
+  const HloInstruction* operand = operand_constraint.operand();
+  const HloInstruction* user = operand_constraint.instruction();
+  if (!ShapeUtil::IsArray(operand->shape()) ||
+      !ShapeUtil::IsArray(user->shape())) {
+    // Non-array (e.g. tuple) shapes are handled by the def-propagation above;
+    // no output layout is chosen for them here.
+    return Status::OK();
+  }
+
+  // Only try to choose a low cost layout if the instruction 'user' defines its
+  // output (ie, doesn't forward a buffer from elsewhere).
+  if (constraints->OperandBufferForwarded(user,
+                                          operand_constraint.operand_no())) {
+    return Status::OK();
+  }
+  TF_ASSIGN_OR_RETURN(
+      const LogicalBuffer* buffer,
+      constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{}));
+
+  // Only suggest an output layout if none has been constrained yet; an
+  // existing buffer constraint takes precedence.
+  if (constraints->BufferLayout(*buffer) == nullptr) {
+    std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
+        operand_constraint.shape_layout().layout(), user,
+        operand_constraint.operand_no());
+    if (layout != nullptr) {
+      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(*layout, *buffer));
+    }
+  }
+  return Status::OK();
+}
+
+Status LayoutAssignment::PropagateBufferConstraint(
+    const BufferLayoutConstraint& buffer_constraint,
+    LayoutConstraints* constraints) {
+  // Only propagate array layouts.
+  const LogicalBuffer& buffer = buffer_constraint.buffer();
+  if (!buffer.IsArray()) {
+    return Status::OK();
+  }
+
+  // If this buffer is the result of an array-shaped op (as opposed to an array
+  // element in a tuple) try to propagate the layout to its operands.
+  if (buffer.IsTopLevel()) {
+    const HloInstruction* instruction = buffer.instruction();
+    // Propagate the def-constraint on an instruction to the use-constraints on
+    // its operands (use-def propagation).
+    for (int64 operand_no = 0; operand_no < instruction->operand_count();
+         ++operand_no) {
+      // Skip operands that already have a constraint or are not arrays.
+      if (constraints->OperandLayout(instruction, operand_no) == nullptr &&
+          ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) {
+        std::unique_ptr<Layout> operand_layout =
+            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
+                                                instruction, operand_no);
+        // A nullptr result means no preferred operand layout was found.
+        if (operand_layout != nullptr) {
+          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
+              *operand_layout, instruction, operand_no));
+        }
+      }
+    }
+  }
+
+  // Propagate the layout to all array uses of the logical buffer. This skips
+  // uses of the buffer where the buffer is the element of a tuple.
+  for (const auto& user_operand_no :
+       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
+    const HloInstruction* user = user_operand_no.first;
+    int64 operand_no = user_operand_no.second;
+    // Only add an operand constraint if the user does not forward the buffer
+    // because this case is not handled in SetOperandLayout.
+    if (constraints->OperandLayout(user, operand_no) == nullptr &&
+        !constraints->OperandBufferForwarded(user, operand_no)) {
+      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
+          buffer_constraint.layout(), user, operand_no));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status LayoutAssignment::PropagateResultConstraint(
+    const ResultLayoutConstraint& result_constraint,
+    LayoutConstraints* constraints) {
+  // Propagate the use constraint of the root instruction up to the logical
+  // buffers which make up the result. The result layout acts as a use of the
+  // root's buffers, so the same use->def propagation applies.
+  return PropagateUseConstraintToDefs(
+      result_constraint.shape_layout(),
+      constraints->computation()->root_instruction(), constraints);
+}
+
+namespace {
+
+// Infers the layout of the array at the given index in the given instruction's
+// output using points-to analysis. Precondition: The given instruction must
+// not produce this array value (that is, the array is forwarded from the
+// instruction's operands).
+StatusOr<Layout> InferArrayLayout(
+    const TuplePointsToAnalysis& points_to_analysis,
+    HloInstruction* instruction, const ShapeIndex& index) {
+  // This function should only be called for array shapes which don't yet have
+  // layouts.
+  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
+  TF_RET_CHECK(ShapeUtil::IsArray(subshape));
+  TF_RET_CHECK(!subshape.has_layout());
+
+  // The instruction should not define the buffer at this index.
+  TF_RET_CHECK(
+      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index));
+
+  const std::vector<const LogicalBuffer*>& source_buffers =
+      points_to_analysis.GetPointsToSet(instruction).element(index);
+  TF_RET_CHECK(!source_buffers.empty());
+
+  // Verify the layout is the same for every LogicalBuffer which this location
+  // ('instruction' and 'index') points to.
+  const Layout* first_buffer_layout = nullptr;
+  for (const LogicalBuffer* source_buffer : source_buffers) {
+    if (!source_buffer->shape().has_layout()) {
+      // This should not happen because we've assigned layouts to all
+      // instructions preceding this one.
+      return InternalError("LogicalBuffer %s does not have a layout",
+                           source_buffer->ToString().c_str());
+    }
+
+    if (first_buffer_layout == nullptr) {
+      first_buffer_layout = &source_buffer->shape().layout();
+    } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
+                                  *first_buffer_layout)) {
+      // The points-to set is ambiguous for this index and the different source
+      // buffers have different layouts. This case is possible in valid XLA
+      // computations because we do not propagate BufferLayoutConstraints to all
+      // LogicalBuffers which may alias the constrained LogicalBuffer at some
+      // point in the computation.
+      return FailedPrecondition(
+          "Array at index {%s} in instruction %s aliases buffers %s "
+          "and %s which have different layouts",
+          tensorflow::str_util::Join(index, ",").c_str(),
+          instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
+          source_buffer->ToString().c_str());
+    }
+  }
+
+  return *first_buffer_layout;
+}
+
+// Creates and returns a copy of the given instruction with a different
+// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
+// instruction producing the copy is returned.
+StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+    const Shape& shape_with_layout, HloInstruction* instruction) {
+  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
+  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()));
+
+  if (ShapeUtil::IsTuple(instruction->shape())) {
+    // Deep-copy tuples.
+    std::vector<HloInstruction*> element_copies;
+    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+         ++i) {
+      HloInstruction* gte = instruction->parent()->AddInstruction(
+          HloInstruction::CreateGetTupleElement(
+              ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+              i));
+
+      // Recurse to copy each element.
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * element_copy,
+          CreateCopyWithNewLayout(
+              ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
+      element_copies.push_back(element_copy);
+    }
+    // Gather element copies into a tuple with a new Tuple instruction.
+    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
+        HloInstruction::CreateTuple(element_copies));
+    // Clear the copy's layout, then stamp on the requested one.
+    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        shape_with_layout, tuple_copy->mutable_shape()));
+    return tuple_copy;
+  } else if (ShapeUtil::IsArray(instruction->shape())) {
+    HloInstruction* copy =
+        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+            instruction->shape(), HloOpcode::kCopy, instruction));
+    // Clear the copy's layout, then stamp on the requested one.
+    LayoutUtil::ClearLayout(copy->mutable_shape());
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        shape_with_layout, copy->mutable_shape()));
+
+    return copy;
+  } else {
+    return FailedPrecondition(
+        "Can only copy array and tuple shaped instructions");
+  }
+}
+
+// Creates a copy of the given operand if the operand's layout does not match
+// the given layout. This copy replaces the use in the given instruction. Tuple
+// operands will be deep-copied.
+Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+                                  HloInstruction* instruction,
+                                  int64 operand_no) {
+  HloInstruction* operand = instruction->mutable_operand(operand_no);
+  TF_RET_CHECK(operand_layout.LayoutIsSet());
+  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
+
+  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+    // Operand layout already matches our constraint. Nothing to do.
+    return Status::OK();
+  }
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
+                      CreateCopyWithNewLayout(operand_layout.shape(), operand));
+
+  // Propagate any failure from rewiring the operand instead of silently
+  // dropping the returned Status.
+  TF_RETURN_IF_ERROR(
+      instruction->ReplaceOperandWith(operand_no, operand_copy));
+  return Status::OK();
+}
+
+// For fusion instructions, set the layout of each fused parameter instruction
+// to match the layout of its corresponding fusion instruction operand. Also,
+// set the layout of the fused root to match the layout of the fusion
+// instruction itself.
+// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple
+// element array pointer load can be added.
+Status SetFusionLayouts(HloInstruction* fusion) {
+  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
+  for (auto& fused_instruction : fusion->fused_instructions()) {
+    if (fused_instruction->opcode() == HloOpcode::kParameter) {
+      // Fused parameters mirror the layout of the corresponding fusion
+      // operand.
+      const HloInstruction* fusion_operand =
+          fusion->operand(fused_instruction->parameter_number());
+      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
+                                   fused_instruction->shape()));
+      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+          fusion_operand->shape(), fused_instruction->mutable_shape()));
+    } else if (fused_instruction.get() == fusion->fused_expression_root()) {
+      // The layout of the root of the fused expression must match the fusion
+      // instruction layout.
+      DCHECK(
+          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
+      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+          fusion->shape(), fused_instruction->mutable_shape()));
+    } else if (fused_instruction->opcode() != HloOpcode::kConstant &&
+               fused_instruction->opcode() != HloOpcode::kGetTupleElement &&
+               fused_instruction->opcode() != HloOpcode::kInfeed) {
+      // Internal fused instructions other than constants, GetTupleElements
+      // (see the TBAA note above), and infeeds need no layout.
+      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
+    }
+  }
+
+  return Status::OK();
+}
+
+} // namespace
+
+Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
+                                       HloComputation* computation) {
+  VLOG(2) << "Assigning layouts to computation: " << computation->name();
+  XLA_VLOG_LINES(2, computation->ToString());
+  XLA_VLOG_LINES(2, constraints.ToString());
+
+  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
+    // Start from a clean slate; every layout below is (re)assigned from the
+    // constraints.
+    LayoutUtil::ClearLayout(instruction->mutable_shape());
+
+    // Create a copy of an operand if the operand instruction's layout does not
+    // match the use constraint (OperandLayoutConstraint).
+    for (int64 operand_no = 0; operand_no < instruction->operand_count();
+         ++operand_no) {
+      const ShapeLayout* operand_layout =
+          constraints.OperandLayout(instruction, operand_no);
+      if (operand_layout != nullptr) {
+        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
+                                                      instruction, operand_no));
+      }
+    }
+
+    // Set the layouts of the array shapes this instruction defines as
+    // indicated by the respective BufferLayoutConstraints. Any array shapes
+    // in the output of the instruction which are not defined by the instruction
+    // (eg, array elements in a Tuple instruction) will be assigned below via
+    // inference.
+    for (const LogicalBuffer* buffer :
+         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
+             instruction)) {
+      if (!ShapeUtil::IsArray(buffer->shape())) {
+        continue;
+      }
+
+      TF_RET_CHECK(buffer->instruction() == instruction);
+      Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
+          instruction->mutable_shape(), buffer->index());
+      // Every defined array buffer must have been constrained by this point.
+      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
+      TF_RET_CHECK(buffer_layout != nullptr);
+      *buffer_subshape->mutable_layout() = *buffer_layout;
+    }
+
+    // Any remaining layouts in the output of the instruction must be
+    // inferrable using points-to analysis.
+    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshape(
+        instruction->mutable_shape(),
+        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
+          if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) {
+            return Status::OK();
+          }
+          // Set Layout of subshape to match layout of LogicalBuffer which
+          // produces it.
+          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
+                              InferArrayLayout(constraints.points_to_analysis(),
+                                               instruction, index));
+          return Status::OK();
+        }));
+
+    // Fusion instructions require some layouts to be set on fused instructions
+    // inside the fusion instruction.
+    if (instruction->opcode() == HloOpcode::kFusion) {
+      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
+    }
+
+    // Verify all layouts in the shape have been set.
+    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
+  }
+
+  // Copy the root instruction's result if it does not match the result
+  // layout constraint.
+  if (constraints.ResultLayout() != nullptr &&
+      !constraints.ResultLayout()->MatchesLayoutInShape(
+          computation->root_instruction()->shape())) {
+    TF_ASSIGN_OR_RETURN(
+        HloInstruction * new_root,
+        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+                                computation->root_instruction()));
+    computation->set_root_instruction(new_root);
+  }
+
+  return Status::OK();
+}
+
+Status LayoutAssignment::RunOnComputation(
+    const ComputationLayout& computation_layout, HloComputation* computation) {
+  // Assigns layouts to all instructions in 'computation' so that its
+  // parameters and result conform to 'computation_layout'.
+  DCHECK(computation_layout.LayoutIsSet());
+  InsertOrDie(&computation_layouts_, computation, computation_layout);
+  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
+          << ")";
+  VLOG(2) << "  ComputationLayout = " << computation_layout.ToString();
+
+  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+                      TuplePointsToAnalysis::Run(computation->parent()));
+
+  // Construct LayoutConstraints with all layout constraints of the
+  // computation.
+  LayoutConstraints constraints(*points_to_analysis, computation);
+
+  // Add constraints required for correctness on all backends (eg, entry
+  // parameter layout constraints).
+  TF_RETURN_IF_ERROR(
+      AddMandatoryConstraints(computation_layout, computation, &constraints));
+
+  // Add any backend-specific constraints.
+  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
+
+  // Propagates layouts from an HLO to its neighbors.
+  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+
+  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
+  // layout and propagate the change.
+  while (!constraints.unconstrained_buffer_ids().empty()) {
+    // Use the container's own size type to avoid a narrowing conversion and a
+    // signed/unsigned comparison in the progress check below.
+    const auto unconstrained_count =
+        constraints.unconstrained_buffer_ids().size();
+
+    // Arbitrarily pick the first unconstrained buffer and give it the default
+    // layout. By construction unconstrained_buffers() has a stable sort based
+    // on LogicalBuffer::Id.
+    const LogicalBuffer& buffer = points_to_analysis->GetBuffer(
+        *constraints.unconstrained_buffer_ids().begin());
+    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(
+        LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer));
+
+    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+
+    // To verify progress has been made, check that the number of unconstrained
+    // buffers has been reduced.
+    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
+             unconstrained_count);
+  }
+
+  // All logical buffers should have constraints at this point. All that
+  // remains is assign the constraints to the buffers and infer layouts for
+  // aliased buffers.
+  return AssignLayouts(constraints, computation);
+}
+
+// Resets and reassigns layouts for the whole module; always reports a change.
+StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
+  VLOG(2) << "Running layout assignment on module " << module->name();
+  XLA_VLOG_LINES(3, module->ToString());
+  if (VLOG_IS_ON(10)) {
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
+                                "before layout assignment",
+                                /*show_addresses=*/false,
+                                /*show_layouts=*/true);
+  }
+
+  // Assign layouts to computations in an order such that a callee computation
+  // is handled before its caller computation. This ensures that the layout of
+  // all callers of a computation will agree.
+  for (auto* computation : module->MakeComputationPostOrder()) {
+    if (computation == module->entry_computation()) {
+      // The entry computation's layout is dictated by the caller-provided
+      // entry_computation_layout_.
+      TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_,
+                                          module->entry_computation()));
+    } else {
+      ComputationLayout computation_layout(computation->ComputeProgramShape());
+      // Setting all embedded computations to the default layout is potentially
+      // suboptimal.
+      computation_layout.SetToDefaultLayout();
+      TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation));
+    }
+  }
+
+  // Check the assigned layouts against the recorded computation layouts.
+  TF_RETURN_IF_ERROR(CheckLayouts(module, computation_layouts_));
+
+  VLOG(3) << "After layout assignment:";
+  XLA_VLOG_LINES(3, module->ToString());
+  if (VLOG_IS_ON(10)) {
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
+                                "after layout assignment",
+                                /*show_addresses=*/false,
+                                /*show_layouts=*/true);
+  }
+
+  // All layouts are reset then reassigned by this pass.
+  return true;
+}
+
+} // namespace xla