aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-07-16 13:29:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 13:33:05 -0700
commitbf87fa55a3280961bec8f31256016d362faf7c30 (patch)
tree67a32d9ffea4d0741fdd9b6a01f0d2da9aeefad0 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parent0c432e8e2eafc4320b55d0f083806982b0ccc218 (diff)
[XLA] Simplify A*0 and 0*A to 0 for integral types.
PiperOrigin-RevId: 204797049
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index af7728da54..2205a7ec18 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1156,6 +1156,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
+ // 0*A => 0. Only applies for integral types for correct NaN-handling.
+ if (IsAll(lhs, 0) &&
+ primitive_util::IsIntegralType(multiply->shape().element_type()) &&
+ ReplaceInstructionIfSameShape(multiply, lhs)) {
+ return Status::OK();
+ }
+ // A*0 => 0
+ if (IsAll(rhs, 0) &&
+ primitive_util::IsIntegralType(multiply->shape().element_type()) &&
+ ReplaceInstructionIfSameShape(multiply, rhs)) {
+ return Status::OK();
+ }
+
// exp(A) * exp(B) => exp(A+B)
if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(