diff options
author | Bixia Zheng <bixia@google.com> | 2018-03-23 12:02:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-25 02:49:19 -0700 |
commit | d11dc5f7eefab6eb5590cee616bab8af2786e273 (patch) | |
tree | 3caf54e760814e05a880f7731177283876a3e6db /tensorflow/compiler/xla/service/reshape_mover.cc | |
parent | 1ece83d0cc98d09447e9a142605934328e18d15a (diff) |
[XLA] Allow reshape mover to move transpose across broadcast of a scalar value.
This allows the simplification of pattern "transpose elementwise-ops inversed
transpose" to "elementwise-ops".
Add a test case.
PiperOrigin-RevId: 190254501
Diffstat (limited to 'tensorflow/compiler/xla/service/reshape_mover.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/reshape_mover.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b..f15117f45c 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,6 +53,14 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } +// Returns true if `a` is a broadcast instruction to target shape `shape` and +// its operand is a scalar. +bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) { + return a->opcode() == HloOpcode::kBroadcast && + ShapeUtil::SameDimensions(a->shape(), shape) && + ShapeUtil::IsScalar(a->operand(0)->shape()); +} + // Returns true iff `instruction` can change its shape simply by adjusting // metadata. bool CanTriviallyChangeShape(const HloInstruction* instruction) { @@ -88,6 +96,7 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + return false; } @@ -148,6 +157,8 @@ bool AllOperandsHaveEasyShapeChanges( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, + // or + // 3. Are broadcast with a scalar operand. for (const HloInstruction* operand : instruction->operands()) { if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { VLOG(5) << "Operand shape differs from output shape; may be " @@ -158,6 +169,12 @@ bool AllOperandsHaveEasyShapeChanges( return false; } + // Skip the rest checks if the current operand is first_reshape_operand + // itself. + if (first_reshape_operand == operand) { + continue; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " << first_reshape_operand->ToString(print_no_metadata) @@ -171,6 +188,12 @@ bool AllOperandsHaveEasyShapeChanges( continue; } + if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) { + VLOG(5) << "Broadcast scalar to shape: " + << operand->ToString(print_no_metadata); + continue; + } + // TODO(someone): Look into supporting general ops for the operands as // well. VLOG(5) << "Operand is neither equalivant to the first Reshape operand" @@ -222,6 +245,12 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: + CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape())); + VLOG(5) << "Changing broadcast"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } |