author    Bixia Zheng <bixia@google.com>                     2018-03-23 12:02:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-25 02:49:19 -0700
commit    d11dc5f7eefab6eb5590cee616bab8af2786e273 (patch)
tree      3caf54e760814e05a880f7731177283876a3e6db /tensorflow/compiler/xla/service/reshape_mover.cc
parent    1ece83d0cc98d09447e9a142605934328e18d15a (diff)
[XLA] Allow reshape mover to move transpose across broadcast of a scalar value.

This allows the pattern "transpose, elementwise-ops, inverse transpose" to be
simplified to "elementwise-ops". Add a test case.

PiperOrigin-RevId: 190254501
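Concretely, the pattern being targeted looks roughly like the following HLO
(a hypothetical sketch; instruction names and shapes are illustrative, not
taken from the commit's test case):

    p0 = f32[3,2] parameter(0)
    t0 = f32[2,3] transpose(p0), dimensions={1,0}
    c0 = f32[] constant(2)
    b0 = f32[2,3] broadcast(c0), dimensions={}     // broadcast of a scalar
    a0 = f32[2,3] add(t0, b0)
    r0 = f32[3,2] transpose(a0), dimensions={1,0}  // inverse of the first transpose

Because b0 broadcasts a scalar, the reshape mover can now rewrite a0 as
transpose(add(p0, b0')), where b0' broadcasts c0 to p0's shape. The new
transpose and the inverse transpose r0 then sit back to back and cancel,
leaving only the elementwise add on p0.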
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover.cc')
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.cc | 29
1 file changed, 29 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index e62bafc50b..f15117f45c 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -53,6 +53,14 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) {
          instruction->opcode() == HloOpcode::kTranspose;
 }
 
+// Returns true if `a` is a broadcast instruction to target shape `shape` and
+// its operand is a scalar.
+bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) {
+  return a->opcode() == HloOpcode::kBroadcast &&
+         ShapeUtil::SameDimensions(a->shape(), shape) &&
+         ShapeUtil::IsScalar(a->operand(0)->shape());
+}
+
 // Returns true iff `instruction` can change its shape simply by adjusting
 // metadata.
 bool CanTriviallyChangeShape(const HloInstruction* instruction) {
@@ -88,6 +96,7 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) {
   if (instruction->opcode() == HloOpcode::kRng &&
       instruction->user_count() == 1) {
     return true;
   }
+
   return false;
 }
@@ -148,6 +157,8 @@ bool AllOperandsHaveEasyShapeChanges(
   //      or
   //   2. Are one of kConstant, kRng, and scalars that can change shape
   //      trivially,
+  //      or
+  //   3. Are broadcasts of a scalar operand.
   for (const HloInstruction* operand : instruction->operands()) {
     if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
       VLOG(5) << "Operand shape differs from output shape; may be "
@@ -158,6 +169,12 @@ bool AllOperandsHaveEasyShapeChanges(
       return false;
     }
 
+    // Skip the remaining checks if the current operand is
+    // first_reshape_operand itself.
+    if (first_reshape_operand == operand) {
+      continue;
+    }
+
     if (AreEquivalentReshapes(first_reshape_operand, operand)) {
       VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
               << first_reshape_operand->ToString(print_no_metadata)
@@ -171,6 +188,12 @@ bool AllOperandsHaveEasyShapeChanges(
       continue;
     }
 
+    if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) {
+      VLOG(5) << "Broadcast scalar to shape: "
+              << operand->ToString(print_no_metadata);
+      continue;
+    }
+
     // TODO(someone): Look into supporting general ops for the operands as
     // well.
     VLOG(5) << "Operand is neither equivalent to the first Reshape operand"
@@ -222,6 +245,12 @@ HloInstruction* UpdateOperand(HloComputation* computation,
       VLOG(5) << "Using existing operand of kReshape or kTranspose";
       return operand->mutable_operand(0);
     }
+    case HloOpcode::kBroadcast:
+      CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape()));
+      VLOG(5) << "Changing broadcast";
+      return computation->AddInstruction(
+          operand->CloneWithNewOperands(new_shape, operand->operands()));
+
     default:
       LOG(FATAL) << "Unexpected operand opcode during update: "
                  << operand->ToString();
   }
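
The kBroadcast case above clones the broadcast with the post-transpose shape
while keeping its scalar operand, which is value-preserving: a scalar
broadcast fills every element with the same value, so the dimension
permutation is irrelevant. A hypothetical test in the style of
reshape_mover_test.cc (builder calls and literal helpers assumed from the XLA
APIs of this era; the commit's actual test case may differ) would exercise it
like this:

    // Build r0 = transpose(add(transpose(p0), broadcast(c0))), where the two
    // transposes are inverses, then run ReshapeMover and expect a rewrite.
    auto builder = HloComputation::Builder("scalar_broadcast_transpose");
    Shape s32 = ShapeUtil::MakeShape(F32, {3, 2});
    Shape s23 = ShapeUtil::MakeShape(F32, {2, 3});
    auto* p0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, s32, "p0"));
    auto* t0 = builder.AddInstruction(
        HloInstruction::CreateTranspose(s23, p0, {1, 0}));
    auto* c0 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
    auto* b0 = builder.AddInstruction(
        HloInstruction::CreateBroadcast(s23, c0, {}));
    auto* a0 = builder.AddInstruction(
        HloInstruction::CreateBinary(s23, HloOpcode::kAdd, t0, b0));
    builder.AddInstruction(HloInstruction::CreateTranspose(s32, a0, {1, 0}));
    auto module = CreateNewModule();  // HloTestBase helper.
    module->AddEntryComputation(builder.Build());
    EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());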