aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2018-06-22 17:24:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-22 17:29:59 -0700
commiteee35a70ce18568b9cd59378fb9b7f3c34d806d9 (patch)
treecd6f31548117096d8c6c644f08c8d11aa33a96b1 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parentb2b89083ae7f2da52ba1310f8224a46a9f64a437 (diff)
[XLA] Disallow implicit scalar broadcast.
PiperOrigin-RevId: 201765455
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc126
1 files changed, 29 insertions, 97 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index d8a9aba834..928aba913b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -50,20 +50,15 @@ namespace {
namespace m = match;
-// Returns whether operand is a literal with the given value.
-bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
- return operand->opcode() == HloOpcode::kConstant &&
- operand->literal().IsAll(value);
-}
-
bool IsAll(const HloInstruction* op, int8 value) {
- if (IsLiteralWithValue(op, value)) {
- return true;
- }
- if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) {
- return true;
+ switch (op->opcode()) {
+ case HloOpcode::kBroadcast:
+ return IsAll(op->operand(0), value);
+ case HloOpcode::kConstant:
+ return op->literal().IsAll(value);
+ default:
+ return false;
}
- return false;
}
// Returns whether the given transpose produces a result which is bit-wise
@@ -160,9 +155,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleMap(HloInstruction* map) override;
- Status HandleMaximum(HloInstruction* maximum) override;
- Status HandleMinimum(HloInstruction* minimum) override;
-
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
@@ -201,8 +193,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Helper method to perform and add reduction in a single dimension.
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
- HloInstruction* zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction* zero =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ Literal::Zero(hlo->shape().element_type()).CloneToUnique()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -633,14 +626,16 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
- HloInstruction* one =
- computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(a->shape().element_type()).CloneToUnique()));
- HloInstruction* inverse = computation_->AddInstruction(
- HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kMultiply, a, inverse));
+ TF_ASSIGN_OR_RETURN(
+ auto big_one,
+ Literal::One(b->shape().element_type()).Broadcast(b->shape(), {}));
+ HloInstruction* one = computation_->AddInstruction(
+ HloInstruction::CreateConstant(std::move(big_one)));
+ TF_ASSIGN_OR_RETURN(auto inverse,
+ MakeBinaryHlo(HloOpcode::kDivide, one, b));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
+ return ReplaceInstruction(divide, new_divide);
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
@@ -660,18 +655,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
TF_ASSIGN_OR_RETURN(auto b_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, b, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a, b_times_c));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
+ return ReplaceInstruction(divide, new_divide);
}
// A / (B / C) => (A*C) / B
if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
TF_ASSIGN_OR_RETURN(auto a_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, a, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a_times_c, b));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
+ return ReplaceInstruction(divide, new_divide);
}
return Status::OK();
@@ -2074,10 +2069,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ Literal::Zero(convolution->shape().element_type())
+ .CloneToUnique())),
{}));
}
const auto& window = convolution->window();
@@ -2249,68 +2243,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
return ReplaceWithNewInstruction(map, std::move(clone));
}
-Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
- // Match the following tree:
- // min_operand operand
- // \ /
- // max_operand min
- // \ /
- // max
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* min;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMinimum, maximum,
- /*matching_operand=*/&min,
- /*other_operand=*/&max_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, min,
- /*matching_operand=*/&min_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
-Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
- // Match the following tree:
- // max_operand operand
- // \ /
- // min_operand max
- // \ /
- // min
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* max;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMaximum, minimum,
- /*matching_operand=*/&max,
- /*other_operand=*/&min_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, max,
- /*matching_operand=*/&max_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());