aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/layout_assignment.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-17 20:37:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-17 20:40:48 -0700
commita8076b9450ca7873592c115841bdebf5f3febf52 (patch)
tree902bb6665f9e0525f2d986de5605cc215915e868 /tensorflow/compiler/xla/service/layout_assignment.cc
parentc696dcf24438fdb29394e776f1c865e0167cd368 (diff)
[XLA] Try to pass layouts through reshapes.
For reshapes where the operand and the output have the same rank, try to pass the layout through the reshape. The layout that's already present was presumably assigned for some reason, so it has a good chance of being good. PiperOrigin-RevId: 172555906
Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc20
1 file changed, 18 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 2058706f11..7eda7c2284 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -732,7 +732,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
// dimension bound is 1 in the operand shape, there may be several such
// layouts. So if 'output_layout' is the default layout, try if the
// reshape is a bitcast when using the same layout. This may avoid copy
- // operations.
+ // operations. For similar reasons, if the operand and output have the same
+ // rank, try to match the operand's layout to the output.
if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
ShapeUtil::Rank(instruction->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
@@ -748,6 +749,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
return MakeUnique<Layout>(operand_shape.layout());
}
+ if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
+ *operand_shape.mutable_layout() = output_layout;
+ if (ShapeUtil::ReshapeIsBitcast(operand_shape,
+ output_shape_with_layout)) {
+ return MakeUnique<Layout>(output_layout);
+ }
+ }
auto aligned_operand_shape =
ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
if (aligned_operand_shape) {
@@ -796,7 +804,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
// dimension bound is 1 in the user shape, there may be several such
// layouts. So if 'operand_layout' is the default layout, try if the
// reshape is a bitcast when using the same layout. This may avoid copy
- // operations.
+ // operations. For similar reasons, if the operand and output have the same
+ // rank, try to match the output's layout to the operand.
if (ShapeUtil::Rank(operand->shape()) == 1 &&
ShapeUtil::TrueRank(user->shape()) == 1) {
// Don't assign a layout in case of R1 -> effective R1 reshape.
@@ -812,6 +821,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
return MakeUnique<Layout>(output_shape.layout());
}
+ if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
+ *output_shape.mutable_layout() = operand_layout;
+ if (ShapeUtil::ReshapeIsBitcast(output_shape,
+ operand_shape_with_layout)) {
+ return MakeUnique<Layout>(operand_layout);
+ }
+ }
auto aligned_user_shape =
ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
if (aligned_user_shape) {