diff options
author | Kay Zhu <kayzhu@google.com> | 2017-05-23 17:49:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-23 17:53:45 -0700 |
commit | ad013db8a9243f20766d7226e82e2cb79fe721c0 (patch) | |
tree | 3e323f009c250418a54b7a37004b8c77f90f493b /tensorflow/compiler/xla/service/reshape_mover.cc | |
parent | afcd75baa433f22447175d63091261dcda6209e3 (diff) |
[XLA] In ReshapeMover: only sink Reshape/Transpose below elementwise HLO, if the
number of non-trivial reshape/transpose operands is non-zero. (Otherwise there's
no benefit to sink Reshape/Transpose).
PiperOrigin-RevId: 156936778
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover.cc | 109 |
1 files changed, 70 insertions, 39 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 768977ba6b..9a788413df 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -64,7 +64,7 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { return nullptr; } -// Check if an operand of an instruction can change its shape simply by +// Checks if an operand of an instruction can change its shape simply by // adjusting metadata. This is the case if an operand does not have any // producers like Constants or Rng instruction, or is a scalar. bool OperandCanTrivallyChangeShape(const HloInstruction* instruction, @@ -129,43 +129,73 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( VLOG(3) << "** Checking whether instruction is an elementwise operation of " "equivalent reshapes/transposes: " << instruction->ToStringNoMetadata(); - bool result = - (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty() && - // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, it may be - // implicitly broadcast, which can confound the movement's - // correctness. - // 1. Are all reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. Can be any shape like kConstant, kRng, and scalars. - std::all_of( - operands.begin(), operands.end(), - [instruction, first_reshape_operand](const HloInstruction* operand) { - if (!ShapeUtil::SameDimensions(operand->shape(), - instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToStringNoMetadata() << "\n\tinstruction: " - << instruction->ToStringNoMetadata(); - return false; - } - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); - return true; - } - if (OperandCanTrivallyChangeShape(instruction, operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); - return true; - } - return false; - }); + bool result = (instruction->user_count() > 0 || + instruction == instruction->parent()->root_instruction()) && + instruction->IsElementwise() && !operands.empty(); + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, it may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Are one of kConstant, kRng, and scalars that can change shape + // trivially, + // + // And: + // 3. The number of operands of kReshape/kTranspose type are greater than + // the 1, and their associated operands are not constant. In other words, + // the number of eliminable non-trivial reshapes is greater than 1. + if (result) { + int nontrivial_reshape_operands = 0; + for (auto& operand : operands) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + result = false; + break; + } + + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + CHECK_EQ(operand->operand_count(), 1); + if (!OperandCanTrivallyChangeShape(operand, operand->operand(0))) { + VLOG(5) << "Reshape/Transpose is nontrivial because its operand " + "cannot trivially change shape: " + << operand->ToStringNoMetadata(); + nontrivial_reshape_operands++; + } + continue; + } + + if (OperandCanTrivallyChangeShape(instruction, operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToStringNoMetadata(); + continue; + } + + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is neither equalivant to the first Reshape operand" + "nor can trivially change shape: " + << operand->ToStringNoMetadata(); + result = false; + break; + } + if (nontrivial_reshape_operands == 0) { + VLOG(5) << "No eliminable and non-trivial reshapes found."; + result = false; + } + } + VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " << instruction->ToStringNoMetadata() << ": " << result; return result; @@ -180,7 +210,6 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, return false; } - std::vector<HloInstruction*> operands = instruction->operands(); HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); TF_RET_CHECK(old_reshape != nullptr); Shape new_elementwise_shape = old_reshape->operand(0)->shape(); @@ -190,6 +219,8 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, << "\n\told reshape: " << old_reshape->ToStringNoMetadata() << "\n\tnew elementwise shape: " << ShapeUtil::HumanString(new_elementwise_shape); + + std::vector<HloInstruction*> operands = instruction->operands(); for (size_t i = 0; i < operands.size(); ++i) { // All scalar operands remain as-is, even if they're reshape or transpose, // to simplify handling wrt special scalar broadcast rules for ops like |