 tensorflow/compiler/xla/service/BUILD                   |   3
 tensorflow/compiler/xla/service/computation_layout.cc   |   7
 tensorflow/compiler/xla/service/computation_layout.h    |   5
 tensorflow/compiler/xla/service/hlo_instruction.h       |   8
 tensorflow/compiler/xla/service/layout_assignment.cc    | 328
 tensorflow/compiler/xla/service/layout_assignment.h     |  65
 tensorflow/compiler/xla/service/service.cc              |   5
 tensorflow/compiler/xla/service/tuple_simplifier.cc     |  25
 8 files changed, 325 insertions(+), 121 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9555d91817..bc577c173d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1953,10 +1953,12 @@ cc_library(
deps = [
":computation_layout",
":hlo",
+ ":hlo_dce",
":hlo_graph_dumper",
":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
+ ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -2433,6 +2435,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
index d2d4f14fce..cb61f3da39 100644
--- a/tensorflow/compiler/xla/service/computation_layout.cc
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -23,12 +23,15 @@ limitations under the License.
namespace xla {
-ComputationLayout::ComputationLayout(const ProgramShape& program_shape)
+ComputationLayout::ComputationLayout(const ProgramShape& program_shape,
+ bool ignore_layouts)
: result_layout_(program_shape.result()) {
for (auto& shape : program_shape.parameters()) {
parameter_layouts_.emplace_back(shape);
}
- SetToDefaultLayout();
+ if (ignore_layouts) {
+ SetToDefaultLayout();
+ }
}
void ComputationLayout::SetToDefaultLayout() {
diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h
index 80e102411c..53c3a3f7b7 100644
--- a/tensorflow/compiler/xla/service/computation_layout.h
+++ b/tensorflow/compiler/xla/service/computation_layout.h
@@ -34,8 +34,9 @@ class ComputationLayout {
public:
// Constructs a ComputationLayout from a ProgramShape. The layouts of the
// parameters and results are set to the default layout. Layouts in the
- // ProgramShape are ignored.
- explicit ComputationLayout(const ProgramShape& program_shape);
+ // ProgramShape are ignored if ignore_layouts is true.
+ explicit ComputationLayout(const ProgramShape& program_shape,
+ bool ignore_layouts = true);
// Returns the layout of a particular parameter.
const ShapeLayout& parameter_layout(int64 param_no) const {
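A hedged usage sketch of the new constructor flag (the helper function and program_shape value are illustrative, not part of this patch):

    #include "tensorflow/compiler/xla/service/computation_layout.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    namespace xla {

    void BuildLayoutsExample(const ProgramShape& program_shape) {
      // ignore_layouts = true (the default): any layouts carried by the
      // ProgramShape are discarded and everything is reset to the default
      // layout, preserving the old behavior.
      ComputationLayout defaulted(program_shape);

      // ignore_layouts = false: layouts already present in the ProgramShape are
      // kept, and parameters/results whose shapes carry no layout remain unset.
      ComputationLayout preserved(program_shape, /*ignore_layouts=*/false);
    }

    }  // namespace xla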
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a5e9aecb9e..f3da3fc256 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -956,6 +956,14 @@ class HloInstruction {
void clear_sharding() { sharding_ = nullptr; }
// Return true if this operator has a sharding assigned.
bool has_sharding() const { return sharding_ != nullptr; }
+ // Checks whether the instruction has compatible sharding with the other
+ // instruction.
+ bool has_compatible_sharding(const HloInstruction* other) const {
+ if (!has_sharding()) {
+ return !other->has_sharding();
+ }
+ return other->has_sharding() ? sharding() == other->sharding() : false;
+ }
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
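In effect the new predicate is: compatible iff neither instruction has a sharding, or both do and the shardings compare equal. A standalone sketch of that truth table, with std::optional standing in for HloSharding (illustrative names only, not XLA API):

    #include <optional>
    #include <string>

    // Mirrors has_compatible_sharding(): both unsharded, or both sharded with
    // equal sharding values.
    bool ShardingsCompatible(const std::optional<std::string>& a,
                             const std::optional<std::string>& b) {
      if (!a.has_value()) {
        return !b.has_value();
      }
      return b.has_value() && *a == *b;
    }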
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2494569db5..7067b6f86a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -31,10 +31,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const {
}
Status LayoutAssignment::AddMandatoryConstraints(
- const ComputationLayout& computation_layout,
- const ChannelLayoutConstraints* channel_constraints,
- HloComputation* computation, LayoutConstraints* constraints) {
+ const ComputationLayout* computation_layout,
+ ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
+ LayoutConstraints* constraints) {
VLOG(3) << "Adding mandatory layout constraints to computation "
<< computation->name();
@@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
instruction->outfeed_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
- // Parameter layouts must match the respective layout in
- // ComputationLayout.
- shape_with_layout =
- &computation_layout.parameter_layout(instruction->parameter_number())
- .shape();
+ if (computation_layout != nullptr) {
+ const ShapeLayout& parameter_layout =
+ computation_layout->parameter_layout(
+ instruction->parameter_number());
+ if (parameter_layout.LayoutIsSet()) {
+ // Parameter layouts must match the respective layout in
+ // ComputationLayout, if there is one.
+ shape_with_layout = &parameter_layout.shape();
+ }
+ }
}
if (shape_with_layout != nullptr) {
TF_RETURN_IF_ERROR(
@@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints(
HloComputation* body = instruction->while_body();
HloComputation* condition = instruction->while_condition();
const HloInstruction* init = instruction->operand(0);
- const ComputationLayout& body_layout =
- FindOrDie(computation_layouts_, body);
- const ComputationLayout& condition_layout =
+ ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
+ ComputationLayout& condition_layout =
FindOrDie(computation_layouts_, condition);
// Check a few invariants irrespective of layout.
@@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
condition_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
- // Return error if earlier layout assignment of the embedded computations
- // has produced conflicting layouts.
- if (!ShapeUtil::Equal(body_layout.result_shape(),
- body_layout.parameter_shape(0))) {
- return InternalError(
- "Parameter and result of body computation %s of while instruction "
- "%s have different layouts: %s vs %s",
- body->name().c_str(), instruction->name().c_str(),
- ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
- ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
+ if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
+ VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
+ << " while=" << instruction->name()
+ << " shape=" << body_layout.result_layout().ToString();
+ *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
}
- if (!ShapeUtil::Equal(body->root_instruction()->shape(),
- condition->parameter_instruction(0)->shape())) {
- return InternalError(
- "Parameter of condition computation %s of while instruction "
- "%s does not match body computation %s result: %s vs %s",
- condition->name().c_str(), instruction->name().c_str(),
- body->name().c_str(),
- ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
- ShapeUtil::HumanString(body_layout.result_shape()).c_str());
+ if (condition_layout.parameter_layout(0) !=
+ body_layout.parameter_layout(0)) {
+ VLOG(2) << "Reset %while condition parameter layout: cond="
+ << condition->name() << " while=" << instruction->name()
+ << " shape=" << body_layout.parameter_layout(0).ToString();
+ *condition_layout.mutable_parameter_layout(0) =
+ body_layout.parameter_layout(0);
}
// Constrain the output and the operand of the while instruction to match
@@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints(
true_computation_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(
false_operand->shape(), false_computation_layout.parameter_shape(0)));
-
+ if (true_computation_layout.result_layout() !=
+ false_computation_layout.result_layout()) {
+ // We assign layouts in DFS fashion, so the true and false computations
+ // might have negotiated different layouts. But from the conditional
+ // instruction's point of view the layouts must match, so we run again on the
+ // false computation, this time with the proper computation layout.
+ VLOG(2) << "Reset %conditional false computation result layout: "
+ "false_computation="
+ << false_computation->name()
+ << " conditional=" << instruction->name() << " shape="
+ << true_computation_layout.result_layout().ToString();
+ *false_computation_layout.mutable_result_layout() =
+ true_computation_layout.result_layout();
+ }
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
true_computation_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
@@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints(
}
}
}
-
- // Finally set the result layout to match ComputationLayout.
- return constraints->SetResultLayout(
- computation_layout.result_layout().shape());
+ // Finally set the result layout to match ComputationLayout, if there is one.
+ if (computation_layout != nullptr) {
+ const ShapeLayout& result_layout = computation_layout->result_layout();
+ if (result_layout.LayoutIsSet()) {
+ TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
+ }
+ }
+ return Status::OK();
}
namespace {
@@ -760,6 +776,7 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
HloInstruction* copy =
instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
instruction->shape(), HloOpcode::kCopy, instruction));
+ RegisterAddedCopy(copy);
SetupCopiedInstruction(*instruction, copy, {});
LayoutUtil::ClearLayout(copy->mutable_shape());
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
@@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+ VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
+ << instruction->ToString();
// Operand layout already matches our constraint. Nothing to do.
return Status::OK();
}
+ VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
+ << operand_layout.ToString() << " in " << instruction->ToString();
TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
CreateCopyWithNewLayout(operand_layout.shape(), operand));
+ VLOG(4) << "New copy of " << operand->ToString() << " is "
+ << operand_copy->ToString();
return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
@@ -896,15 +919,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
}
}
}
-
- // Finally verify the result layout matches the layout of the entry
+ // Finally verify the result layout, if set, matches the layout of the entry
// computation root.
- TF_RET_CHECK(ShapeUtil::Equal(
- module->entry_computation()->root_instruction()->shape(),
+ const ShapeLayout& result_layout =
FindOrDie(computation_layouts_, module->entry_computation())
- .result_layout()
- .shape()));
-
+ .result_layout();
+ if (result_layout.LayoutIsSet()) {
+ TF_RET_CHECK(ShapeUtil::Equal(
+ module->entry_computation()->root_instruction()->shape(),
+ result_layout.shape()));
+ }
return Status::OK();
}
@@ -913,18 +937,13 @@ LayoutAssignment::LayoutAssignment(
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
- VLOG(1) << "entry computation layout given to layout assignment: "
+ VLOG(1) << "Entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
// Layouts of all parameter instructions must be set.
for (const ShapeLayout& parameter_layout :
entry_computation_layout_->parameter_layouts()) {
CHECK(parameter_layout.LayoutIsSet());
}
- // If the result layout is not set, then choose the default.
- // TODO(b/29118294): Choose a better layout in this case.
- if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
- entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
- }
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1484,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
return Status::OK();
}
+Status LayoutAssignment::CalculateComputationLayout(
+ HloComputation* computation) {
+ ComputationLayout computation_layout(computation->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ InsertOrDie(&computation_layouts_, computation, computation_layout);
+ VLOG(2) << " Calculated ComputationLayout = "
+ << computation_layout.ToString();
+ return Status::OK();
+}
+
+Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
+ // Clear existing layouts of the instructions. All layouts must be assigned
+ // by the LayoutAssignment pass, except for those on infeeds, parameters,
+ // and the computation result. The latter two are specified in
+ // computation_layout, so we only need to keep the existing layouts for
+ // infeeds. Clearing the layouts here avoids hiding potential bugs in the
+ // layout assignment pass that may accidentally use the existing layout.
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kBitcast) {
+ // Bitcasts are inherently layout-sensitive, so a bitcast instruction
+ // present in the IR before layout assignment is a bug.
+ return InternalError(
+ "Unexpected bitcast operation seen during layout assignment: %s.",
+ instruction->ToString().c_str());
+ }
+ if (instruction->opcode() != HloOpcode::kInfeed) {
+ LayoutUtil::ClearLayout(instruction->mutable_shape());
+ }
+ }
+ return Status::OK();
+}
+
Status LayoutAssignment::RunOnComputation(
- const ComputationLayout& computation_layout,
+ ComputationLayout* computation_layout,
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints) {
- DCHECK(computation_layout.LayoutIsSet());
- InsertOrDie(&computation_layouts_, computation, computation_layout);
VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
<< ")";
- VLOG(2) << " ComputationLayout = " << computation_layout.ToString();
+ TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
+ if (computation_layout != nullptr) {
+ auto it = computation_layouts_.find(computation);
+ if (it == computation_layouts_.end()) {
+ VLOG(2) << " New ComputationLayout = " << computation_layout->ToString();
+ computation_layouts_.emplace(computation, *computation_layout);
+ } else {
+ TF_RET_CHECK(computation_layout == &it->second ||
+ computation_layout == entry_computation_layout_);
+ VLOG(2) << " Existing ComputationLayout = "
+ << computation_layout->ToString();
+ }
+ } else {
+ VLOG(2) << " No ComputationLayout specified (will be calculated)";
+ }
// Construct LayoutConstraints with all layout constraints of the computation.
LayoutConstraints constraints(points_to_analysis, computation);
@@ -1536,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation(
CHECK_LT(constraints.unconstrained_buffer_ids().size(),
unconstrained_count);
}
-
// All logical buffers should have constraints at this point. All that
// remains is assign the constraints to the buffers and infer layouts for
// aliased buffers.
TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
+ // If the computation layout wasn't specified, now is the time to compute it,
+ // according to the layouts assigned to the parameters and root instruction.
+ // This allows the first pass through this API to record the layouts that
+ // naturally flow to the parameters and root instruction.
+ if (computation_layout == nullptr) {
+ TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
+ }
+
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
for (HloInstruction* instruction : computation->instructions()) {
@@ -1556,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation(
return Status::OK();
}
+Status LayoutAssignment::PropagateComputationLayouts(
+ HloComputation* computation, ComputationLayout* computation_layout) {
+ ComputationLayout computed_computation_layout(
+ computation->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
+ for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
+ ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
+ if (!param_layout->LayoutIsSet()) {
+ VLOG(4) << "Assigning layout to parameter " << i << " of computation "
+ << computation->name() << ": "
+ << computed_computation_layout.parameter_layout(i).ToString();
+ *param_layout = computed_computation_layout.parameter_layout(i);
+ } else {
+ TF_RET_CHECK(computed_computation_layout.parameter_layout(i) ==
+ *param_layout);
+ }
+ }
+ ShapeLayout* result_layout = computation_layout->mutable_result_layout();
+ if (!result_layout->LayoutIsSet()) {
+ VLOG(4) << "Assigning result layout of computation " << computation->name()
+ << ": " << computed_computation_layout.result_layout().ToString();
+ *result_layout = computed_computation_layout.result_layout();
+ } else {
+ TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
+ }
+ return Status::OK();
+}
+
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
VLOG(2) << "Running layout assignment on module " << module->name();
XLA_VLOG_LINES(3, module->ToString());
@@ -1564,52 +1662,45 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
"before layout assignment",
module->config().debug_options());
}
-
- TF_ASSIGN_OR_RETURN(auto points_to_analysis,
- TuplePointsToAnalysis::Run(module));
-
- // Assign layouts to computations in an order such that a callee computation
- // is handled before its caller computation. This ensures that the layout of
- // all callers of a computation will agree.
- std::list<HloComputation*> computation_post_order =
- module->MakeComputationPostOrder();
- for (auto* computation : module->MakeComputationPostOrder()) {
- if (computation->IsFusionComputation()) {
- continue;
- }
- // Clear existing layouts of the instructions. All layouts must be assigned
- // by the LayoutAssignment pass, except for those on infeeds, parameters,
- // and the computation result. The latter two are specified in
- // computation_layout, so we only need to keep the existing layouts for
- // infeeds. Clearing the layouts here avoids hiding potential bugs in the
- // layout assignment pass that may accidently use the existing layout.
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kBitcast) {
- // bitcasts are inherently layout sensitive and so a bitcast instruction
- // present in the IR before layout assignment is a bug.
- return InternalError(
- "Unexpected bitcast operation seen during layout assignment: %s.",
- instruction->ToString().c_str());
+ TF_RETURN_IF_ERROR(Init());
+
+ // We do two passes. In the first, we pass a nullptr ComputationLayout to
+ // the RunOnComputation() calls (for non-entry computations), and we record
+ // the ComputationLayouts which naturally flow, in DFS fashion, to the
+ // parameters and root instruction.
+ // Walking in DFS order, though, means we can end up with layouts that are
+ // incorrect when seen from an outer instruction, which has cross-computation
+ // constraints to impose.
+ // For example, the kWhile instruction needs to enforce the same layouts for
+ // the parameters and root of the body, as well as the condition parameters.
+ // Similarly, the kConditional instruction needs to enforce the same layouts
+ // for the roots of the true and false computations.
+ // So in the first pass, while allowing the layouts to flow to parameters and
+ // root, we also fix up any inconsistent ComputationLayouts, which are then
+ // made mandatory by the second pass.
+ for (int64 i = 0; i < 2; ++i) {
+ TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+ for (auto* computation : module->MakeComputationPostOrder()) {
+ if (computation->IsFusionComputation()) {
+ continue;
}
- if (instruction->opcode() != HloOpcode::kInfeed) {
- LayoutUtil::ClearLayout(instruction->mutable_shape());
+ if (computation == module->entry_computation()) {
+ TF_RETURN_IF_ERROR(RunOnComputation(
+ entry_computation_layout_, *points_to_analysis,
+ module->entry_computation(), channel_layout_constraints_));
+ } else {
+ ComputationLayout* computation_layout =
+ (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation);
+ TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
+ *points_to_analysis, computation,
+ channel_layout_constraints_));
}
}
- if (computation == module->entry_computation()) {
- TF_RETURN_IF_ERROR(RunOnComputation(
- *entry_computation_layout_, *points_to_analysis,
- module->entry_computation(), channel_layout_constraints_));
- } else {
- ComputationLayout computation_layout(computation->ComputeProgramShape());
- // Setting all embedded computations to the default layout is potentially
- // suboptimal.
- computation_layout.SetToDefaultLayout();
- TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
- *points_to_analysis, computation,
- channel_layout_constraints_));
- }
}
-
+ TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
+ entry_computation_layout_));
TF_RETURN_IF_ERROR(CheckLayouts(module));
VLOG(3) << "After layout assignment:";
@@ -1619,9 +1710,54 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
"after layout assignment",
module->config().debug_options());
}
-
// All layouts are reset then reassigned by this pass.
return true;
}
+Status LayoutAssignment::Init() {
+ computation_layouts_.clear();
+ return Status::OK();
+}
+
+Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
+ // Clear all the copies which have been added, and all the related
+ // instructions (like GTE and tuples).
+ int64 removed_copies = 0;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ if (instruction->opcode() == HloOpcode::kCopy &&
+ added_copies_.count(instruction) > 0) {
+ VLOG(5) << "Removing added copy: " << instruction->ToString();
+ TF_RETURN_IF_ERROR(
+ instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+ ++removed_copies;
+ }
+ }
+ }
+ added_copies_.clear();
+ if (removed_copies > 0) {
+ TupleSimplifier tuple_simplifier;
+ HloDCE dce;
+ TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+ TF_RETURN_IF_ERROR(dce.Run(module).status());
+ }
+ return Status::OK();
+}
+
+Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
+ int64 operand_number) {
+ HloInstruction* operand = instruction->mutable_operand(operand_number);
+ if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
+ HloInstruction* copy =
+ instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+ operand->shape(), HloOpcode::kCopy, operand));
+ SetupCopiedInstruction(*operand, copy, {});
+ LayoutUtil::ClearLayout(copy->mutable_shape());
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
+ }
+ return Status::OK();
+}
+
} // namespace xla
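For orientation, a condensed restatement of the new control flow in Run() (error handling, logging, and the final propagation/check steps elided; a sketch of the shape of the loop, not a drop-in replacement):

    // Two passes over the non-fusion computations: pass 0 lets layouts flow to
    // parameters/roots and records the resulting ComputationLayouts; pass 1
    // replays assignment with those layouts as hard constraints.
    for (int pass = 0; pass < 2; ++pass) {
      TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
      TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                          TuplePointsToAnalysis::Run(module));
      for (HloComputation* computation : module->MakeComputationPostOrder()) {
        if (computation->IsFusionComputation()) {
          continue;
        }
        ComputationLayout* layout =
            (computation == module->entry_computation())
                ? entry_computation_layout_
                : (pass == 0 ? nullptr
                             : &FindOrDie(computation_layouts_, computation));
        TF_RETURN_IF_ERROR(RunOnComputation(layout, *points_to_analysis,
                                            computation,
                                            channel_layout_constraints_));
      }
    }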
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index ae4986d6ad..8b4e07995a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface {
int64 operand_no);
private:
+ // Initializes the layout assignment object for a new Run() call.
+ Status Init();
+
// Adds constraints which must be satisfied for correctness on all
// backends. Called once prior to propagating constraints.
- Status AddMandatoryConstraints(
- const ComputationLayout& computation_layout,
- const ChannelLayoutConstraints* channel_constraints,
- HloComputation* computation, LayoutConstraints* constraints);
+ Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
+ ChannelLayoutConstraints* channel_constraints,
+ HloComputation* computation,
+ LayoutConstraints* constraints);
// This method can be overridden to add backend-specific constraints to the
// layout of the instructions of a computation. This method is called after
@@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface {
}
// Construct constraints and assign layouts to all instructions in the
- // computation satisfying the given ComputationLayout. Layouts constraints are
- // added, then propagated until all LogicalBuffers in the computation are
- // constrained.
- Status RunOnComputation(const ComputationLayout& computation_layout,
+ // computation satisfying the given ComputationLayout, if not nullptr.
+ // Otherwise the ComputationLayout will be calculated by propagating the
+ // computation instruction constraints.
+ // Layout constraints are added, then propagated until all LogicalBuffers in
+ // the computation are constrained.
+ Status RunOnComputation(ComputationLayout* computation_layout,
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints);
@@ -402,6 +408,25 @@ class LayoutAssignment : public HloPassInterface {
// necessary conditions.
Status CheckLayouts(HloModule* module);
+ // Computes the ComputationLayout of the given computation based on the
+ // layouts assigned to its parameters and root instruction, and inserts it
+ // into the computation_layouts_ map.
+ Status CalculateComputationLayout(HloComputation* computation);
+
+ // Clears all the layouts which can be cleared within a computation.
+ Status ClearComputationLayouts(HloComputation* computation);
+
+ // Clears the side effects of a previous pass, like added copy instructions.
+ Status ClearPreviousPassSideEffects(HloModule* module);
+
+ // Propagates the layouts computed by the layout assignment pass for the given
+ // computation to the computation layout passed in to this API.
+ // Missing layouts are filled in, and the layouts the caller did specify are
+ // checked for consistency against those assigned to the parameters and the
+ // root instruction of the computation.
+ Status PropagateComputationLayouts(HloComputation* computation,
+ ComputationLayout* computation_layout);
+
ComputationLayout* entry_computation_layout_;
protected:
@@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface {
// Creates and returns a copy of the given instruction with a different
// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
// instruction producing the copy is returned.
- static StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ StatusOr<HloInstruction*> CreateCopyWithNewLayout(
const Shape& shape_with_layout, HloInstruction* instruction);
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction.
// Tuple operands will be deep-copied.
- static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
- HloInstruction* instruction,
- int64 operand_no);
+ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no);
+
+ // Registers a copy instruction added by the layout assignment pass.
+ void RegisterAddedCopy(HloInstruction* copy) {
+ CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
+ added_copies_.insert(copy);
+ }
+
+ // Adds a copy for the operand of an instruction, unless that operand is
+ // already a copy with a single user (which is necessarily the instruction
+ // itself).
+ Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
// Map containing the layouts of all computations assigned so
// far. Computations are handled in a topological sort where computations are
// handled before their caller instructions so the layouts of caller
// instructions can be set to match the computation.
std::map<HloComputation*, ComputationLayout> computation_layouts_;
+
+ // Every copy added to the module by the layout assignment pass is registered
+ // here.
+ tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+
ChannelLayoutConstraints* channel_layout_constraints_;
};
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 39f3aefdf8..a73118c68a 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -308,7 +308,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
- computation_layout->mutable_result_layout()->Clear();
+ // TODO(b/78356948): We are forcing the default layout here. We should fix
+ // clients that expect a default layout to be explicit about it, by passing
+ // the proper ExecutionOptions with shape_with_output_layout set.
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(options_.number_of_replicas());
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc
index 113c2e2bd9..d668855084 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc
@@ -69,6 +69,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Tuple
//
HloInstruction* top_tuple = nullptr;
+ HloInstruction* first_gte = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0;
operand_number < instruction->operand_count(); ++operand_number) {
@@ -78,11 +79,17 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
can_simplify = false;
break;
}
-
+ if (first_gte == nullptr) {
+ first_gte = operand;
+ } else if (!first_gte->has_compatible_sharding(operand)) {
+ can_simplify = false;
+ break;
+ }
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(),
- instruction->shape())) {
+ instruction->shape()) ||
+ !instruction->has_compatible_sharding(top_tuple)) {
can_simplify = false;
break;
}
@@ -108,15 +115,17 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// |
// GTE
if (instruction->operand(0)->opcode() == HloOpcode::kTuple) {
- changed = true;
HloInstruction* element_source =
instruction->mutable_operand(0)->mutable_operand(
instruction->tuple_index());
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
- for (HloInstruction* user : element_source->users()) {
- if (user->opcode() == HloOpcode::kTuple ||
- user->opcode() == HloOpcode::kGetTupleElement) {
- worklist.push(user);
+ if (instruction->has_compatible_sharding(element_source)) {
+ changed = true;
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
+ for (HloInstruction* user : element_source->users()) {
+ if (user->opcode() == HloOpcode::kTuple ||
+ user->opcode() == HloOpcode::kGetTupleElement) {
+ worklist.push(user);
+ }
}
}
}
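The net effect on the Tuple-of-GTEs rewrite is that the chain only collapses when every GetTupleElement and the rebuilt tuple agree on sharding. A minimal standalone sketch of that guard, with std::optional standing in for HloSharding (all names here are illustrative, not XLA API):

    #include <optional>
    #include <string>
    #include <vector>

    using Sharding = std::optional<std::string>;

    // Compatible iff both are unsharded, or both are sharded with equal values.
    bool Compatible(const Sharding& a, const Sharding& b) {
      return a.has_value() ? (b.has_value() && *a == *b) : !b.has_value();
    }

    // Models the new check in TupleSimplifier::Run(): Tuple(GTE(t,0), GTE(t,1),
    // ...) may only be replaced by t when the GTEs agree among themselves and
    // the tuple's sharding is compatible with t's.
    bool CanCollapse(const Sharding& tuple_sharding,
                     const Sharding& source_sharding,
                     const std::vector<Sharding>& gte_shardings) {
      for (size_t i = 1; i < gte_shardings.size(); ++i) {
        if (!Compatible(gte_shardings[0], gte_shardings[i])) {
          return false;
        }
      }
      return Compatible(tuple_sharding, source_sharding);
    }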