aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-16 13:31:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 13:34:47 -0800
commit1bde8c61e14b1c42db20a1b07684ea2534cfdf01 (patch)
tree96b49f2ed4f6dc5054c3fa47707dd295ae670dd6 /tensorflow/compiler/xla/service/layout_assignment.cc
parent6e5ca37827d95b12c4712cf237ec2f8124ed885c (diff)
Automated g4 rollback of changelist 185623948
PiperOrigin-RevId: 186038783
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc212
1 files changed, 141 insertions, 71 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index fce135ef61..0668f66051 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -53,6 +53,83 @@ limitations under the License.
namespace xla {
+// For now moving only one API here, but we should have a single top level
+// anonymous namespace, instead of three or four spread all over this file.
+namespace {
+
+// Creates and returns a copy of the given instruction with a different
+// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
+// instruction producing the copy is returned.
+StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ const Shape& shape_with_layout, HloInstruction* instruction) {
+ TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
+ DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
+ << ShapeUtil::HumanString(shape_with_layout) << " "
+ << ShapeUtil::HumanString(instruction->shape())
+ << " instruction: " << instruction->ToString();
+
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // Deep-copy tuples.
+ std::vector<HloInstruction*> element_copies;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+ ++i) {
+ HloInstruction* gte = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+ i));
+
+ // Recurse to copy each elements.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * element_copy,
+ CreateCopyWithNewLayout(
+ ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
+ element_copies.push_back(element_copy);
+ }
+ // Gather element copies into a tuple with a new Tuple instruction.
+ HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(element_copies));
+ LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, tuple_copy->mutable_shape()));
+ return tuple_copy;
+ } else if (ShapeUtil::IsArray(instruction->shape())) {
+ HloInstruction* copy =
+ instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ LayoutUtil::ClearLayout(copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, copy->mutable_shape()));
+
+ return copy;
+ } else {
+ return FailedPrecondition(
+ "Can only copy array and tuple shaped instructions");
+ }
+}
+
+// Creates a copy of the given operand if the operand's layout does not match
+// the given layout. This copy replaces the use in the given instruction. Tuple
+// operands will be deep-copied.
+Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no) {
+ HloInstruction* operand = instruction->mutable_operand(operand_no);
+ TF_RET_CHECK(operand_layout.LayoutIsSet());
+ TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
+
+ if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+ // Operand layout already matches our constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
+ CreateCopyWithNewLayout(operand_layout.shape(), operand));
+
+ return instruction->ReplaceOperandWith(operand_no, operand_copy);
+}
+
+} // namespace
+
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
@@ -512,6 +589,36 @@ Status LayoutAssignment::AddMandatoryConstraints(
body_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
body_layout.result_shape(), instruction, 0));
+ } else if (instruction->opcode() == HloOpcode::kConditional) {
+ // The layout of the true and false computations must match, and must
+ // be the layout of the kConditional instruction.
+ TF_RET_CHECK(instruction->operand_count() == 3);
+
+ HloComputation* true_computation = instruction->true_computation();
+ HloComputation* false_computation = instruction->false_computation();
+ const HloInstruction* true_operand = instruction->operand(1);
+ const HloInstruction* false_operand = instruction->operand(2);
+
+ TF_RET_CHECK(true_computation->num_parameters() == 1);
+ TF_RET_CHECK(false_computation->num_parameters() == 1);
+ ComputationLayout& true_computation_layout =
+ FindOrDie(computation_layouts_, true_computation);
+ ComputationLayout& false_computation_layout =
+ FindOrDie(computation_layouts_, false_computation);
+
+ DCHECK(ShapeUtil::Compatible(true_operand->shape(),
+ true_computation_layout.parameter_shape(0)));
+ DCHECK(ShapeUtil::Compatible(
+ false_operand->shape(), false_computation_layout.parameter_shape(0)));
+
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ true_computation_layout.result_shape(), instruction));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ true_computation_layout.parameter_shape(0), instruction, 1,
+ /*mandatory=*/true));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ false_computation_layout.parameter_shape(0), instruction, 2,
+ /*mandatory=*/true));
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
if (!CustomCallRequiresMajorFirstLayout(instruction)) {
continue;
@@ -598,6 +705,33 @@ Status CheckWhileLayout(HloInstruction* while_inst,
return Status::OK();
}
+Status CheckConditionalLayout(
+ HloInstruction* instruction,
+ const ComputationLayout& true_computation_layout,
+ const ComputationLayout& false_computation_layout) {
+ HloComputation* true_computation = instruction->true_computation();
+ HloComputation* false_computation = instruction->false_computation();
+ const HloInstruction* true_operand = instruction->operand(1);
+ const HloInstruction* false_operand = instruction->operand(2);
+
+ TF_RET_CHECK(true_computation_layout.result_layout() ==
+ false_computation_layout.result_layout());
+ TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
+ instruction->shape()));
+ TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
+ true_computation->root_instruction()->shape()));
+ TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
+ instruction->shape()));
+ TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
+ false_computation->root_instruction()->shape()));
+ TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ true_operand->shape()));
+ TF_RET_CHECK(
+ false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ false_operand->shape()));
+ return Status::OK();
+}
+
// Fusion parameters must match the layout of the fusion instructions operands,
// and the root of the fusion expression must match the layout of the fusion
// instruction.
@@ -710,6 +844,13 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->while_condition()),
FindOrDie(computation_layouts_, instruction->while_body())));
break;
+ case HloOpcode::kConditional:
+ TF_RETURN_IF_ERROR(CheckConditionalLayout(
+ instruction,
+ FindOrDie(computation_layouts_, instruction->true_computation()),
+ FindOrDie(computation_layouts_,
+ instruction->false_computation())));
+ break;
default:
break;
}
@@ -1165,77 +1306,6 @@ StatusOr<Layout> InferArrayLayout(
return *first_buffer_layout;
}
-// Creates and returns a copy of the given instruction with a different
-// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
-// instruction producing the copy is returned.
-StatusOr<HloInstruction*> CreateCopyWithNewLayout(
- const Shape& shape_with_layout, HloInstruction* instruction) {
- TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
- DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
- << ShapeUtil::HumanString(shape_with_layout) << " "
- << ShapeUtil::HumanString(instruction->shape())
- << " instruction: " << instruction->ToString();
-
- if (ShapeUtil::IsTuple(instruction->shape())) {
- // Deep-copy tuples.
- std::vector<HloInstruction*> element_copies;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
- ++i) {
- HloInstruction* gte = instruction->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
- i));
-
- // Recurse to copy each elements.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element_copy,
- CreateCopyWithNewLayout(
- ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
- element_copies.push_back(element_copy);
- }
- // Gather element copies into a tuple with a new Tuple instruction.
- HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
- HloInstruction::CreateTuple(element_copies));
- LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
- TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
- shape_with_layout, tuple_copy->mutable_shape()));
- return tuple_copy;
- } else if (ShapeUtil::IsArray(instruction->shape())) {
- HloInstruction* copy =
- instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- LayoutUtil::ClearLayout(copy->mutable_shape());
- TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
- shape_with_layout, copy->mutable_shape()));
-
- return copy;
- } else {
- return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
- }
-}
-
-// Creates a copy of the given operand if the operand's layout does not match
-// the given layout. This copy replaces the use in the given instruction. Tuple
-// operands will be deep-copied.
-Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
- HloInstruction* instruction,
- int64 operand_no) {
- HloInstruction* operand = instruction->mutable_operand(operand_no);
- TF_RET_CHECK(operand_layout.LayoutIsSet());
- TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
-
- if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
- // Operand layout already matches our constraint. Nothing to do.
- return Status::OK();
- }
-
- TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
- CreateCopyWithNewLayout(operand_layout.shape(), operand));
-
- return instruction->ReplaceOperandWith(operand_no, operand_copy);
-}
-
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion