diff options
author | Peter Hawkins <phawkins@google.com> | 2018-07-16 13:29:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-16 13:33:05 -0700 |
commit | bf87fa55a3280961bec8f31256016d362faf7c30 (patch) | |
tree | 67a32d9ffea4d0741fdd9b6a01f0d2da9aeefad0 /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | 0c432e8e2eafc4320b55d0f083806982b0ccc218 (diff) |
[XLA] Simplify A*0 and 0*A to 0 for integral types.
PiperOrigin-RevId: 204797049
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index af7728da54..2205a7ec18 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1156,6 +1156,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + // 0*A => 0. Only applies for integral types for correct NaN-handling. + if (IsAll(lhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, lhs)) { + return Status::OK(); + } + // A*0 => 0 + if (IsAll(rhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, rhs)) { + return Status::OK(); + } + // exp(A) * exp(B) => exp(A+B) if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( |