author     Mark Heffernan <meheff@google.com>  2018-10-08 14:26:43 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-08 14:34:02 -0700
commit     396a8a4105edd409d0821c4d5d0b920b315ffb72 (patch)
tree       428350d427ffb29470e284077a2734b273b7cc4d /tensorflow/compiler/xla/service/layout_assignment.cc
parent     bc5635dc3ac78007caee88fabd81d23ad945b637 (diff)
Add custom call with layout constraints.
Add a variant of CustomCall which specifies arbitrary layout constraints on the operands and result. The existing non-layout-constrained CustomCall is changed to have no layout preference and can now be assigned arbitrary layouts by layout assignment.

PiperOrigin-RevId: 216249615
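For illustration, a minimal sketch of how such a layout-constrained custom call might be built through the client API. This assumes the xla::CustomCallWithLayout builder entry point that accompanies this change; the call target name "my_custom_call" is hypothetical.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch (assumed API): pin the operand layout to row-major {1,0} and the
// result layout to column-major {0,1}. Layout assignment must now honor
// both constraints, instead of forcing major-first layouts as the old
// kCustomCall handling did. The target "my_custom_call" is hypothetical.
xla::XlaOp BuildConstrainedCall(xla::XlaBuilder* builder, xla::XlaOp operand) {
  xla::Shape operand_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
  xla::Shape result_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
  return xla::CustomCallWithLayout(builder, "my_custom_call", {operand},
                                   result_shape, {operand_shape});
}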
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc | 108
1 file changed, 55 insertions(+), 53 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index cc4a342e9d..ad65b147c1 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints(
return Status::OK();
}
+namespace {
+
+bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ return custom_call != nullptr && custom_call->layout_constrained();
+}
+
+} // namespace
+
Status LayoutAssignment::AddMandatoryConstraints(
const ComputationLayout* computation_layout,
ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
@@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
// Constrain layouts of instructions which define values with pre-existing
// layouts.
for (auto* instruction : computation->instructions()) {
- Shape const* shape_with_layout = nullptr;
if (instruction->opcode() == HloOpcode::kInfeed) {
// Infeed layouts must match the layout of the original inserted
// instruction.
@@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints(
if (parameter_layout.LayoutIsSet()) {
// Parameter layouts must match the respective layout in
// ComputationLayout, if there is one.
- shape_with_layout = &parameter_layout.shape();
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ parameter_layout.shape(), instruction));
}
}
- }
- if (shape_with_layout != nullptr) {
+ } else if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(*shape_with_layout, instruction));
- }
-
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kRecv) {
+ constraints->SetInstructionLayout(custom_call->shape(), custom_call));
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ custom_call->operand_shapes_with_layout()[i], custom_call, i));
+ }
+ } else if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
@@ -621,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
false_computation_layout.parameter_shape(0), instruction, 2,
/*mandatory=*/true));
- } else if (instruction->opcode() == HloOpcode::kCustomCall) {
- if (!CustomCallRequiresMajorFirstLayout(instruction)) {
- continue;
- }
- // Add constraints for kCustomCall instruction operands and instructions.
- // For now we only support major-first layouts for all inputs and outputs.
- Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
- instruction->shape().element_type(),
- AsInt64Slice(instruction->shape().dimensions()));
- TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(result_shape, instruction));
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const Shape& operand_shape = instruction->operand(i)->shape();
- // Opaque operands don't get a layout constraint.
- if (ShapeUtil::IsOpaque(operand_shape)) {
- continue;
- }
-
- Shape row_major_operand_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(
- operand_shape.element_type(),
- AsInt64Slice(operand_shape.dimensions()));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- row_major_operand_shape, instruction, i));
- }
}
}
// Finally set the result layout to match ComputationLayout, if there is one.
@@ -676,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call,
return Status::OK();
}
-// Custom calls have fixed input and output layouts.
-Status CheckCustomCallLayout(HloInstruction* custom_call) {
- for (const HloInstruction* operand : custom_call->operands()) {
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(operand->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+// Operands of layout-constrained custom calls must match the expected
+// constrained layouts.
+Status CheckCustomCallLayout(HloInstruction* instruction) {
+ if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+ custom_call->operand(i)->shape(),
+ custom_call->operand_shapes_with_layout()[i]));
+ }
}
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(custom_call->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
return Status::OK();
}
@@ -932,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->to_apply())));
break;
case HloOpcode::kCustomCall:
- if (CustomCallRequiresMajorFirstLayout(instruction)) {
- TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
- }
+ TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
break;
case HloOpcode::kFusion:
TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
@@ -1554,11 +1542,11 @@ Status LayoutAssignment::CalculateComputationLayout(
Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// Clear existing layouts of the instructions. All layouts must be assigned
- // by the LayoutAssignment pass, except for those on infeeds, parameters,
- // and the computation result. The latter two are specified in
- // computation_layout, so we only need to keep the existing layouts for
- // infeeds. Clearing the layouts here avoids hiding potential bugs in the
- // layout assignment pass that may accidentally use the existing layout.
+ // by the LayoutAssignment pass, except for those on parameters, the
+ // computation result, and a couple special cases. The former two are
+ // specified in computation_layout. Clearing the layouts here avoids hiding
+ // potential bugs in the layout assignment pass that may accidentally use the
+ // existing layout.
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kBitcast) {
// bitcasts are inherently layout sensitive and so a bitcast instruction
@@ -1567,7 +1555,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
"Unexpected bitcast operation seen during layout assignment: %s.",
instruction->ToString());
}
- if (instruction->opcode() != HloOpcode::kInfeed) {
+ // Some instructions carry mandatory layouts in their shape.
+ if (instruction->opcode() != HloOpcode::kInfeed &&
+ !IsLayoutConstrainedCustomCall(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
@@ -1802,6 +1792,18 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
TF_RETURN_IF_ERROR(Init());
+ // Verify computation layout is sane.
+ const HloComputation* entry = module->entry_computation();
+ TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
+ entry->num_parameters());
+ for (int64 i = 0; i < entry->num_parameters(); ++i) {
+ TF_RET_CHECK(
+ ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
+ entry->parameter_instruction(i)->shape()));
+ }
+ TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
+ entry->root_instruction()->shape()));
+
// We do two passes. The first one we pass a nullptr ComputationLayout to
// the RunOnComputation() calls (for non entry computations), and we register
// the ComputationLayout which are naturally flowing in DFS fashion to the
@@ -1873,7 +1875,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
- case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
@@ -1930,6 +1931,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kCopy:
+ case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kFusion: