aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/reshape_mover.cc
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-05-23 17:49:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-23 17:53:45 -0700
commitad013db8a9243f20766d7226e82e2cb79fe721c0 (patch)
tree3e323f009c250418a54b7a37004b8c77f90f493b /tensorflow/compiler/xla/service/reshape_mover.cc
parentafcd75baa433f22447175d63091261dcda6209e3 (diff)
[XLA] In ReshapeMover: only sink Reshape/Transpose below elementwise HLO, if the
number of non-trivial reshape/transpose operands is non-zero. (Otherwise there's no benefit to sink Reshape/Transpose). PiperOrigin-RevId: 156936778
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover.cc')
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc109
1 files changed, 70 insertions, 39 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index 768977ba6b..9a788413df 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -64,7 +64,7 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
return nullptr;
}
-// Check if an operand of an instruction can change its shape simply by
+// Checks if an operand of an instruction can change its shape simply by
// adjusting metadata. This is the case if an operand does not have any
// producers like Constants or Rng instruction, or is a scalar.
bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
@@ -129,43 +129,73 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
VLOG(3) << "** Checking whether instruction is an elementwise operation of "
"equivalent reshapes/transposes: "
<< instruction->ToStringNoMetadata();
- bool result =
- (instruction->user_count() > 0 ||
- instruction == instruction->parent()->root_instruction()) &&
- instruction->IsElementwise() && !operands.empty() &&
- // Check whether all operands:
- // 0. Have the same dimensions as the output -- if not, it may be
- // implicitly broadcast, which can confound the movement's
- // correctness.
- // 1. Are all reshapes or transposes that have the same input and
- // output shapes as all other reshaped or transposed operands.
- // or
- // 2. Can be any shape like kConstant, kRng, and scalars.
- std::all_of(
- operands.begin(), operands.end(),
- [instruction, first_reshape_operand](const HloInstruction* operand) {
- if (!ShapeUtil::SameDimensions(operand->shape(),
- instruction->shape())) {
- VLOG(5) << "Operand shape differs from output shape; may be "
- "implicitly broadcast, so preventing "
- "movement\n\toperand: "
- << operand->ToStringNoMetadata() << "\n\tinstruction: "
- << instruction->ToStringNoMetadata();
- return false;
- }
- if (AreEquivalentReshapes(first_reshape_operand, operand)) {
- VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
- << first_reshape_operand->ToStringNoMetadata()
- << "\n\toperand: " << operand->ToStringNoMetadata();
- return true;
- }
- if (OperandCanTrivallyChangeShape(instruction, operand)) {
- VLOG(5) << "Operand can trivially change shape: "
- << operand->ToStringNoMetadata();
- return true;
- }
- return false;
- });
+ bool result = (instruction->user_count() > 0 ||
+ instruction == instruction->parent()->root_instruction()) &&
+ instruction->IsElementwise() && !operands.empty();
+
+ // Check whether all operands:
+ // 0. Have the same dimensions as the output -- if not, it may be
+ // implicitly broadcast, which can confound the movement's
+ // correctness.
+ //
+ // And one of the following:
+ // 1. Are reshapes or transposes that have the same input and
+ // output shapes as all other reshaped or transposed operands.
+ // or
+ // 2. Are one of kConstant, kRng, and scalars that can change shape
+ // trivially,
+ //
+ // And:
+ // 3. The number of kReshape/kTranspose operands whose own operand cannot
+ // trivially change shape is at least 1. In other words, the number of
+ // eliminable non-trivial reshapes is non-zero.
+ if (result) {
+ int nontrivial_reshape_operands = 0;
+ for (auto& operand : operands) {
+ if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+ VLOG(5) << "Operand shape differs from output shape; may be "
+ "implicitly broadcast, so preventing "
+ "movement\n\toperand: "
+ << operand->ToStringNoMetadata()
+ << "\n\tinstruction: " << instruction->ToStringNoMetadata();
+ result = false;
+ break;
+ }
+
+ if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+ VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+ << first_reshape_operand->ToStringNoMetadata()
+ << "\n\toperand: " << operand->ToStringNoMetadata();
+ CHECK_EQ(operand->operand_count(), 1);
+ if (!OperandCanTrivallyChangeShape(operand, operand->operand(0))) {
+ VLOG(5) << "Reshape/Transpose is nontrivial because its operand "
+ "cannot trivially change shape: "
+ << operand->ToStringNoMetadata();
+ nontrivial_reshape_operands++;
+ }
+ continue;
+ }
+
+ if (OperandCanTrivallyChangeShape(instruction, operand)) {
+ VLOG(5) << "Operand can trivially change shape: "
+ << operand->ToStringNoMetadata();
+ continue;
+ }
+
+ // TODO(someone): Look into supporting general ops for the operands as
+ // well.
+ VLOG(5) << "Operand is neither equalivant to the first Reshape operand"
+ "nor can trivially change shape: "
+ << operand->ToStringNoMetadata();
+ result = false;
+ break;
+ }
+ if (nontrivial_reshape_operands == 0) {
+ VLOG(5) << "No eliminable and non-trivial reshapes found.";
+ result = false;
+ }
+ }
+
VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
<< instruction->ToStringNoMetadata() << ": " << result;
return result;
@@ -180,7 +210,6 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
return false;
}
- std::vector<HloInstruction*> operands = instruction->operands();
HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
TF_RET_CHECK(old_reshape != nullptr);
Shape new_elementwise_shape = old_reshape->operand(0)->shape();
@@ -190,6 +219,8 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
<< "\n\told reshape: " << old_reshape->ToStringNoMetadata()
<< "\n\tnew elementwise shape: "
<< ShapeUtil::HumanString(new_elementwise_shape);
+
+ std::vector<HloInstruction*> operands = instruction->operands();
for (size_t i = 0; i < operands.size(); ++i) {
// All scalar operands remain as-is, even if they're reshape or transpose,
// to simplify handling wrt special scalar broadcast rules for ops like