diff options
author | Bixia Zheng <bixia@google.com> | 2018-06-22 09:54:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-22 09:57:32 -0700 |
commit | 40ff081cd88efb73eee745f85c374efdb7c20542 (patch) | |
tree | febf03da0be9a4d85a2f4a78dfab2bb2e51230f0 /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | b653d1ace6e1a7d9e063d1793b6cde36579426a2 (diff) |
[XLA] Teach algebraic simplifier to convert a copy instruction to a bitcast
instruction.
Extend ReshapeIsBitcast to handle copy instructions. This allows algebraic
simplifier, for example, to replace a copy from shape f16[1,1,128,128]{3,2,1,0}
to shape f16[1,1,128,128]{1,0,3,2} to bitcast for the CPU and GPU backends.
Add a test case.
PiperOrigin-RevId: 201699400
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 1fc8fb9b69..d8a9aba834 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -75,21 +75,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape produces a result which is bit-wise +// Returns true if the given reshape/copy produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. // // This function is conservative -- even if this function returns false, the // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeIsBitcast( - const HloInstruction* reshape, +bool ReshapeOrCopyIsBitcast( + const HloInstruction* instr, const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { - CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); + CHECK(HloOpcode::kReshape == instr->opcode() || + HloOpcode::kCopy == instr->opcode()); - const HloInstruction* operand = reshape->operand(0); + const HloInstruction* operand = instr->operand(0); // Can't insert bitcasts if the compiler used a memory layout which isn't // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && - valid_bitcast_callback(operand->shape(), reshape->shape()); + return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && + valid_bitcast_callback(operand->shape(), instr->shape()); } // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -433,7 +434,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); + if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { + return Status::OK(); + } + + if (is_layout_sensitive_ && + ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + ReplaceWithBitcast(copy); + } + return Status::OK(); } @@ -1672,7 +1681,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Make this a bitcast if possible. if (is_layout_sensitive_ && - ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { ReplaceWithBitcast(reshape); return Status::OK(); } |