aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-13 21:15:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-13 21:18:50 -0700
commit6a581e1d7c28f5b8f487f2a91649d7e2866974f4 (patch)
tree725bdcde41e0ad3712332589daaef4503f62ef05 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parentcc9a8f789a4d224a3e73737fa6c921676441a6c8 (diff)
[XLA] Use pattern matcher in algebraic simplifier
PiperOrigin-RevId: 192862841
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc226
1 files changed, 107 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 6cb1bd5669..cd5737e4f9 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -44,8 +45,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
+
namespace {
+namespace m = match;
+
// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
return operand->opcode() == HloOpcode::kConstant &&
@@ -105,6 +109,7 @@ HloComputation* CreateScalarBinaryComputation(HloModule* module,
module->AddEmbeddedComputation(b.Build(scalar_op));
return scalar_computation;
}
+
} // namespace
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -350,8 +355,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
}
Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
- auto lhs = add->mutable_operand(0);
- auto rhs = add->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));
+
// A + 0 => A
VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
@@ -366,7 +372,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
// Canonicalization: Put constants on the right. This makes the reassociation
// rules below simpler.
VLOG(10) << "trying transform [Const + A => A + Const]";
- if (lhs->IsConstant() && !rhs->IsConstant()) {
+ if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
return ReplaceWithNewInstruction(
add,
HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
@@ -379,16 +385,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
// (A + C1) + (B + C2) => A + B + (C1 + C2).
//
VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
- if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd &&
- !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) {
- auto* c1 = lhs->mutable_operand(1);
- auto* c2 = rhs;
-
+ HloInstruction *a, *c1, *c2;
+ if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
+ m::Constant(&c2)))) {
TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
return ReplaceWithNewInstruction(
- add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
- lhs->mutable_operand(0),
+ add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
sum_of_constants));
}
@@ -397,11 +400,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
// If a bitcast feeds a bitcast, make it a single bitcast.
- if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
+ HloInstruction* op;
+ if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
return ReplaceWithNewInstruction(
- bitcast, HloInstruction::CreateUnary(
- bitcast->shape(), HloOpcode::kBitcast,
- bitcast->mutable_operand(0)->mutable_operand(0)));
+ bitcast,
+ HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op));
}
// All bitcasts can be eliminated (assuming layout constraints are
// satisified).
@@ -418,11 +421,10 @@ Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
- if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
+ HloInstruction* op;
+ if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
return ReplaceWithNewInstruction(
- copy, HloInstruction::CreateUnary(
- copy->shape(), HloOpcode::kCopy,
- copy->mutable_operand(0)->mutable_operand(0)));
+ copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
}
// All copies can be eliminated (assuming layout constraints are satisified).
ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
@@ -462,12 +464,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
} else if (operands.size() == 2) {
// A binary concat with a broadcasted scalar as an operand can be converted
// into a pad which is simpler to fold into other operations.
- bool is_effective_low_pad =
- operands[0]->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::IsScalar(operands[0]->operand(0)->shape());
- bool is_effective_high_pad =
- operands[1]->opcode() == HloOpcode::kBroadcast &&
- ShapeUtil::IsScalar(operands[1]->operand(0)->shape());
+ bool is_effective_low_pad = Match(
+ operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
+ bool is_effective_high_pad = Match(
+ operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
if (!is_effective_low_pad && !is_effective_high_pad) {
return Status::OK();
}
@@ -537,8 +537,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
}
Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
- auto lhs = sub->mutable_operand(0);
- auto rhs = sub->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
// A - 0 => A
VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
@@ -547,7 +547,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
// Canonicalize subtraction of a constant to addition.
VLOG(10) << "trying transform [A - Const => A + (-Const)]";
- if (rhs->IsConstant() && !lhs->IsConstant()) {
+ if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
HloInstruction* negative_const = computation_->AddInstruction(
HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
return ReplaceWithNewInstruction(
@@ -559,56 +559,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
}
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
- auto lhs = divide->mutable_operand(0);
- auto rhs = divide->mutable_operand(1);
+ Shape* shape;
+ HloInstruction *a, *b, *c, *d;
+ CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
// A/1 => A
VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
- if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
+ if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
return Status::OK();
}
// exp(A)/exp(B) => exp(A-B)
- if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
+ .WithShape(m::Shape(&shape)))) {
VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
- HloInstruction* subtract =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
- rhs->mutable_operand(0)));
+ HloInstruction* subtract = computation_->AddInstruction(
+ HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
- subtract));
+ divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
}
// A/exp(B) => A*exp(-B)
- if (rhs->opcode() == HloOpcode::kExp) {
+ if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
- HloInstruction* negate =
- computation_->AddInstruction(HloInstruction::CreateUnary(
- divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0)));
+ HloInstruction* negate = computation_->AddInstruction(
+ HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
HloInstruction* new_exp = computation_->AddInstruction(
HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, new_exp));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kMultiply, a, new_exp));
}
// A/pow(B,C) => A*pow(B,-C)
- if (rhs->opcode() == HloOpcode::kPower) {
+ if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
// The output shape of the created negate operator should be the same as the
// input.
- const Shape& negate_shape = rhs->operand(1)->shape();
- HloInstruction* negate =
- computation_->AddInstruction(HloInstruction::CreateUnary(
- negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1)));
+ const Shape& negate_shape = c->shape();
+ HloInstruction* negate = computation_->AddInstruction(
+ HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
// And the power operator should retain the output shape of the old one.
- const Shape& new_power_shape = rhs->shape();
- HloInstruction* new_power = computation_->AddInstruction(
- HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower,
- rhs->mutable_operand(0), negate));
+ const Shape& new_power_shape = b->shape();
+ HloInstruction* new_power =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ new_power_shape, HloOpcode::kPower, b, negate));
return ReplaceWithNewInstruction(
divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, new_power));
+ divide->shape(), HloOpcode::kMultiply, a, new_power));
}
// Simplifying integral division would produce unexpected results.
@@ -620,28 +617,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
//
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
- if (lhs->opcode() != HloOpcode::kConstant &&
- rhs->opcode() == HloOpcode::kConstant) {
+ if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
HloInstruction* one =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(lhs->shape().element_type()).CloneToUnique()));
- HloInstruction* inverse =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- rhs->shape(), HloOpcode::kDivide, one, rhs));
+ Literal::One(a->shape().element_type()).CloneToUnique()));
+ HloInstruction* inverse = computation_->AddInstruction(
+ HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b));
return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(
- divide->shape(), HloOpcode::kMultiply, lhs, inverse));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kMultiply, a, inverse));
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
- if (lhs->opcode() == HloOpcode::kDivide &&
- rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply,
- lhs->mutable_operand(0),
- rhs->mutable_operand(1)));
- TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply,
- lhs->mutable_operand(1),
- rhs->mutable_operand(0)));
+ if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
+ m::Divide(m::Op(&c), m::Op(&d))))) {
+ TF_ASSIGN_OR_RETURN(auto a_times_d,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, d));
+ TF_ASSIGN_OR_RETURN(auto b_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, b, c));
TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
a_times_d, b_times_c));
@@ -649,24 +642,21 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
}
// (A / B) / C => A / (B * C)
- if (lhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(
- auto b_times_c,
- MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
+ if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
+ TF_ASSIGN_OR_RETURN(auto b_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, b, c));
return ReplaceWithNewInstruction(
- divide,
- HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
- lhs->mutable_operand(0), b_times_c));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kDivide, a, b_times_c));
}
// A / (B / C) => (A*C) / B
- if (rhs->opcode() == HloOpcode::kDivide) {
- TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs,
- rhs->mutable_operand(1)));
+ if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
+ TF_ASSIGN_OR_RETURN(auto a_times_c,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, c));
return ReplaceWithNewInstruction(
- divide,
- HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
- a_times_c, rhs->mutable_operand(0)));
+ divide, HloInstruction::CreateBinary(divide->shape(),
+ HloOpcode::kDivide, a_times_c, b));
}
return Status::OK();
@@ -674,8 +664,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
HloInstruction* dot) {
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
int64 lhs_collapsing_dim =
dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
if (lhs->IsRank2Transpose()) {
@@ -792,8 +782,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
- HloInstruction* lhs = dot->mutable_operand(0);
- HloInstruction* rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
TF_ASSIGN_OR_RETURN(
HloInstruction * optimized_lhs_concat,
@@ -923,8 +913,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
}
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
- auto lhs = dot->mutable_operand(0);
- auto rhs = dot->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
// Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
// below.
@@ -976,8 +966,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
}
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
- auto lhs = multiply->mutable_operand(0);
- auto rhs = multiply->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
// A*1 => A
VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
@@ -990,10 +980,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
}
// exp(A) * exp(B) => exp(A+B)
- if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
- multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0),
- rhs->mutable_operand(0)));
+ multiply->shape(), HloOpcode::kAdd, lhs, rhs));
return ReplaceWithNewInstruction(
multiply,
HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
@@ -1004,20 +993,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
// ln(exp(A)) => A
VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
- auto operand = log->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kExp &&
- ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
+ HloInstruction *a, *b;
+ if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
+ ReplaceInstructionIfSameShape(log, a)) {
return Status::OK();
}
// ln(pow(A,B)) => B*ln(A)
- if (operand->opcode() == HloOpcode::kPower) {
- auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary(
- log->shape(), HloOpcode::kLog, operand->mutable_operand(0)));
+ if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
+ auto new_log = computation_->AddInstruction(
+ HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
return ReplaceWithNewInstruction(
- log,
- HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
- new_log, operand->mutable_operand(1)));
+ log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
+ new_log, b));
}
return Status::OK();
@@ -1120,7 +1108,8 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
} // namespace
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
- auto operand = broadcast->mutable_operand(0);
+ HloInstruction* operand;
+ CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
auto dims = broadcast->dimensions();
// A degenerate broadcast of a reshape that does not change the number of
// elements can be replaced by a reshape.
@@ -1231,30 +1220,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
- auto real = complex->mutable_operand(0);
- auto imag = complex->mutable_operand(1);
- if (real->opcode() == HloOpcode::kReal &&
- imag->opcode() == HloOpcode::kImag &&
- real->operand(0) == imag->operand(0)) {
- return ReplaceInstruction(complex, real->mutable_operand(0));
+ HloInstruction *c0, *c1;
+ if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
+ c0 == c1) {
+ return ReplaceInstruction(complex, c0);
}
return Status::OK();
}
// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
- auto operand = real->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kComplex) {
- return ReplaceInstruction(real, operand->mutable_operand(0));
+ HloInstruction* op;
+ if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
+ return ReplaceInstruction(real, op);
}
return Status::OK();
}
// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
- auto operand = imag->mutable_operand(0);
- if (operand->opcode() == HloOpcode::kComplex) {
- return ReplaceInstruction(imag, operand->mutable_operand(1));
+ HloInstruction* op;
+ if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
+ return ReplaceInstruction(imag, op);
}
return Status::OK();
}
@@ -1351,8 +1338,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
- auto lhs = power->mutable_operand(0);
- auto rhs = power->mutable_operand(1);
+ HloInstruction *lhs, *rhs;
+ CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
Literal::One(power->shape().element_type()).CloneToUnique());
@@ -1372,9 +1359,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
}
// pow(exp(A),B) => exp(A*B)
- if (lhs->opcode() == HloOpcode::kExp) {
+ HloInstruction *a, *b;
+ if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
- power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs));
+ power->shape(), HloOpcode::kMultiply, a, b));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
a_times_b));