diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-13 21:15:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-13 21:18:50 -0700 |
commit | 6a581e1d7c28f5b8f487f2a91649d7e2866974f4 (patch) | |
tree | 725bdcde41e0ad3712332589daaef4503f62ef05 /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | cc9a8f789a4d224a3e73737fa6c921676441a6c8 (diff) |
[XLA] Use pattern matcher in algebraic simplifier
PiperOrigin-RevId: 192862841
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 226 |
1 files changed, 107 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 6cb1bd5669..cd5737e4f9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -44,8 +45,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { + namespace { +namespace m = match; + // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -105,6 +109,7 @@ HloComputation* CreateScalarBinaryComputation(HloModule* module, module->AddEmbeddedComputation(b.Build(scalar_op)); return scalar_computation; } + } // namespace // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -350,8 +355,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( } Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { - auto lhs = add->mutable_operand(0); - auto rhs = add->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); + // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { @@ -366,7 +372,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Canonicalization: Put constants on the right. This makes the reassociation // rules below simpler. VLOG(10) << "trying transform [Const + A => A + Const]"; - if (lhs->IsConstant() && !rhs->IsConstant()) { + if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); @@ -379,16 +385,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // (A + C1) + (B + C2) => A + B + (C1 + C2). // VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; - if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && - !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { - auto* c1 = lhs->mutable_operand(1); - auto* c2 = rhs; - + HloInstruction *a, *c1, *c2; + if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { TF_ASSIGN_OR_RETURN(auto* sum_of_constants, MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( - add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, - lhs->mutable_operand(0), + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); } @@ -397,11 +400,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. - if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { + HloInstruction* op; + if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { return ReplaceWithNewInstruction( - bitcast, HloInstruction::CreateUnary( - bitcast->shape(), HloOpcode::kBitcast, - bitcast->mutable_operand(0)->mutable_operand(0))); + bitcast, + HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); } // All bitcasts can be eliminated (assuming layout constraints are // satisified). @@ -418,11 +421,10 @@ Status AlgebraicSimplifierVisitor::HandleBitcastConvert( Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (copy->operand(0)->opcode() == HloOpcode::kCopy) { + HloInstruction* op; + if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary( - copy->shape(), HloOpcode::kCopy, - copy->mutable_operand(0)->mutable_operand(0))); + copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); @@ -462,12 +464,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } else if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. - bool is_effective_low_pad = - operands[0]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); - bool is_effective_high_pad = - operands[1]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); + bool is_effective_low_pad = Match( + operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); + bool is_effective_high_pad = Match( + operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); if (!is_effective_low_pad && !is_effective_high_pad) { return Status::OK(); } @@ -537,8 +537,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { - auto lhs = sub->mutable_operand(0); - auto rhs = sub->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { @@ -547,7 +547,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (rhs->IsConstant() && !lhs->IsConstant()) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); return ReplaceWithNewInstruction( @@ -559,56 +559,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { } Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - auto lhs = divide->mutable_operand(0); - auto rhs = divide->mutable_operand(1); + Shape* shape; + HloInstruction *a, *b, *c, *d; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { return Status::OK(); } // exp(A)/exp(B) => exp(A-B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) + .WithShape(m::Shape(&shape)))) { VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); - HloInstruction* subtract = - computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + HloInstruction* subtract = computation_->AddInstruction( + HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, - subtract)); + divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); } // A/exp(B) => A*exp(-B) - if (rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); HloInstruction* new_exp = computation_->AddInstruction( HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, new_exp)); } // A/pow(B,C) => A*pow(B,-C) - if (rhs->opcode() == HloOpcode::kPower) { + if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); // The output shape of the created negate operator should be the same as the // input. - const Shape& negate_shape = rhs->operand(1)->shape(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + const Shape& negate_shape = c->shape(); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); // And the power operator should retain the output shape of the old one. - const Shape& new_power_shape = rhs->shape(); - HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, - rhs->mutable_operand(0), negate)); + const Shape& new_power_shape = b->shape(); + HloInstruction* new_power = + computation_->AddInstruction(HloInstruction::CreateBinary( + new_power_shape, HloOpcode::kPower, b, negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + divide->shape(), HloOpcode::kMultiply, a, new_power)); } // Simplifying integral division would produce unexpected results. @@ -620,28 +617,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (lhs->opcode() != HloOpcode::kConstant && - rhs->opcode() == HloOpcode::kConstant) { + if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { HloInstruction* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(lhs->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = - computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kDivide, one, rhs)); + Literal::One(a->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = computation_->AddInstruction( + HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, inverse)); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) - if (lhs->opcode() == HloOpcode::kDivide && - rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply, - lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, - lhs->mutable_operand(1), - rhs->mutable_operand(0))); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), + m::Divide(m::Op(&c), m::Op(&d))))) { + TF_ASSIGN_OR_RETURN(auto a_times_d, + MakeBinaryHlo(HloOpcode::kMultiply, a, d)); + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, a_times_d, b_times_c)); @@ -649,24 +642,21 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { } // (A / B) / C => A / (B * C) - if (lhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - auto b_times_c, - MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - lhs->mutable_operand(0), b_times_c)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a, b_times_c)); } // A / (B / C) => (A*C) / B - if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs, - rhs->mutable_operand(1))); + if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { + TF_ASSIGN_OR_RETURN(auto a_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, a, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - a_times_c, rhs->mutable_operand(0))); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a_times_c, b)); } return Status::OK(); @@ -674,8 +664,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); int64 lhs_collapsing_dim = dot->dot_dimension_numbers().lhs_contracting_dimensions(0); if (lhs->IsRank2Transpose()) { @@ -792,8 +782,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, @@ -923,8 +913,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( } Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or // below. @@ -976,8 +966,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { - auto lhs = multiply->mutable_operand(0); - auto rhs = multiply->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { @@ -990,10 +980,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } // exp(A) * exp(B) => exp(A+B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( - multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + multiply->shape(), HloOpcode::kAdd, lhs, rhs)); return ReplaceWithNewInstruction( multiply, HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); @@ -1004,20 +993,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); - auto operand = log->mutable_operand(0); - if (operand->opcode() == HloOpcode::kExp && - ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { + HloInstruction *a, *b; + if (Match(log, m::Log(m::Exp(m::Op(&a)))) && + ReplaceInstructionIfSameShape(log, a)) { return Status::OK(); } // ln(pow(A,B)) => B*ln(A) - if (operand->opcode() == HloOpcode::kPower) { - auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( - log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); return ReplaceWithNewInstruction( - log, - HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, - new_log, operand->mutable_operand(1))); + log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, b)); } return Status::OK(); @@ -1120,7 +1108,8 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, } // namespace Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { - auto operand = broadcast->mutable_operand(0); + HloInstruction* operand; + CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. @@ -1231,30 +1220,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { // Complex(Real(c), Imag(c)) -> c Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { - auto real = complex->mutable_operand(0); - auto imag = complex->mutable_operand(1); - if (real->opcode() == HloOpcode::kReal && - imag->opcode() == HloOpcode::kImag && - real->operand(0) == imag->operand(0)) { - return ReplaceInstruction(complex, real->mutable_operand(0)); + HloInstruction *c0, *c1; + if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && + c0 == c1) { + return ReplaceInstruction(complex, c0); } return Status::OK(); } // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { - auto operand = real->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(real, operand->mutable_operand(0)); + HloInstruction* op; + if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { + return ReplaceInstruction(real, op); } return Status::OK(); } // Imag(Complex(r, i)) -> i Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { - auto operand = imag->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(imag, operand->mutable_operand(1)); + HloInstruction* op; + if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { + return ReplaceInstruction(imag, op); } return Status::OK(); } @@ -1351,8 +1338,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - auto lhs = power->mutable_operand(0); - auto rhs = power->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( Literal::One(power->shape().element_type()).CloneToUnique()); @@ -1372,9 +1359,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { } // pow(exp(A),B) => exp(A*B) - if (lhs->opcode() == HloOpcode::kExp) { + HloInstruction *a, *b; + if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( - power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + power->shape(), HloOpcode::kMultiply, a, b)); return ReplaceWithNewInstruction( power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, a_times_b)); |