aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Bixia Zheng <bixia@google.com>2018-06-22 09:54:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-22 09:57:32 -0700
commit40ff081cd88efb73eee745f85c374efdb7c20542 (patch)
treefebf03da0be9a4d85a2f4a78dfab2bb2e51230f0 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parentb653d1ace6e1a7d9e063d1793b6cde36579426a2 (diff)
[XLA] Teach algebraic simplifier to convert a copy instruction to a bitcast
instruction. Extend ReshapeIsBitcast to handle copy instructions. This allows algebraic simplifier, for example, to replace a copy from shape f16[1,1,128,128]{3,2,1,0} to shape f16[1,1,128,128]{1,0,3,2} to bitcast for the CPU and GPU backends. Add a test case. PiperOrigin-RevId: 201699400
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 1fc8fb9b69..d8a9aba834 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -75,21 +75,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) {
transpose->dimensions());
}
-// Returns true if the given reshape produces a result which is bit-wise
+// Returns true if the given reshape/copy produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
//
// This function is conservative -- even if this function returns false, the
// reshape may still be a bitcast. For example, a reshape from [28x28] to [784].
-bool ReshapeIsBitcast(
- const HloInstruction* reshape,
+bool ReshapeOrCopyIsBitcast(
+ const HloInstruction* instr,
const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
- CHECK_EQ(HloOpcode::kReshape, reshape->opcode());
+ CHECK(HloOpcode::kReshape == instr->opcode() ||
+ HloOpcode::kCopy == instr->opcode());
- const HloInstruction* operand = reshape->operand(0);
+ const HloInstruction* operand = instr->operand(0);
// Can't insert bitcasts if the compiler used a memory layout which isn't
// compatible.
- return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
- valid_bitcast_callback(operand->shape(), reshape->shape());
+ return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) &&
+ valid_bitcast_callback(operand->shape(), instr->shape());
}
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -433,7 +434,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
}
// All copies can be eliminated (assuming layout constraints are satisified).
- ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
+ if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
+ return Status::OK();
+ }
+
+ if (is_layout_sensitive_ &&
+ ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) {
+ ReplaceWithBitcast(copy);
+ }
+
return Status::OK();
}
@@ -1672,7 +1681,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// Make this a bitcast if possible.
if (is_layout_sensitive_ &&
- ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
+ ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) {
ReplaceWithBitcast(reshape);
return Status::OK();
}