aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc15
1 files changed, 6 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2494569db5..cfa7ba5e81 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -909,22 +909,19 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
}
LayoutAssignment::LayoutAssignment(
- ComputationLayout* entry_computation_layout,
+ const ComputationLayout& entry_computation_layout,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
VLOG(1) << "entry computation layout given to layout assignment: "
- << entry_computation_layout_->ToString();
+ << entry_computation_layout_.ToString();
// Layouts of all parameter instructions must be set.
for (const ShapeLayout& parameter_layout :
- entry_computation_layout_->parameter_layouts()) {
+ entry_computation_layout_.parameter_layouts()) {
CHECK(parameter_layout.LayoutIsSet());
}
- // If the result layout is not set, then choose the default.
- // TODO(b/29118294): Choose a better layout in this case.
- if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
- entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
- }
+ // TODO(b/29118294): Choose a better layout if the result layout is not set.
+ CHECK(entry_computation_layout_.result_layout().LayoutIsSet());
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1597,7 +1594,7 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(
- *entry_computation_layout_, *points_to_analysis,
+ entry_computation_layout_, *points_to_analysis,
module->entry_computation(), channel_layout_constraints_));
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());