aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc93
1 files changed, 68 insertions, 25 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 48fd07371d..505c0e8dff 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -195,7 +196,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -537,8 +538,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar =
- MakeUnique<Literal>(constant->literal().GetFirstScalarLiteral());
+ std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
return ReplaceWithNewInstruction(
@@ -1093,7 +1094,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
ShapeUtil::IsZeroElementArray(lhs->shape()) ||
ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
@@ -1155,6 +1156,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
+ // 0*A => 0. Only applies for integral types for correct NaN-handling.
+ if (IsAll(lhs, 0) &&
+ primitive_util::IsIntegralType(multiply->shape().element_type()) &&
+ ReplaceInstructionIfSameShape(multiply, lhs)) {
+ return Status::OK();
+ }
+ // A*0 => 0
+ if (IsAll(rhs, 0) &&
+ primitive_util::IsIntegralType(multiply->shape().element_type()) &&
+ ReplaceInstructionIfSameShape(multiply, rhs)) {
+ return Status::OK();
+ }
+
// exp(A) * exp(B) => exp(A+B)
if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -1252,9 +1266,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
switch (instruction->opcode()) {
case HloOpcode::kReshape:
case HloOpcode::kReverse:
- case HloOpcode::kSort:
case HloOpcode::kTranspose:
return true;
+ case HloOpcode::kSort:
+ return (!ShapeUtil::IsTuple(instruction->shape()));
default:
return false;
}
@@ -1518,7 +1533,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- Literal::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).CloneToUnique());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1553,7 +1568,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -1729,19 +1744,37 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
return Status::OK();
}
+
+ auto is_unstrided_slice = [](const HloInstruction* hlo) {
+ return c_all_of(hlo->slice_strides(),
+ [](int64 stride) { return stride == 1; });
+ };
+ if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
+ is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) {
+ HloInstruction* operand_slice = slice->mutable_operand(0);
+ std::vector<int64> new_slice_starts = slice->slice_starts();
+ std::vector<int64> new_slice_limits = slice->slice_limits();
+ for (int64 i = 0; i < new_slice_starts.size(); ++i) {
+ new_slice_starts[i] += operand_slice->slice_starts(i);
+ new_slice_limits[i] += operand_slice->slice_starts(i);
+ }
+ return ReplaceWithNewInstruction(
+ slice, HloInstruction::CreateSlice(
+ slice->shape(), operand_slice->mutable_operand(0),
+ new_slice_starts, new_slice_limits, slice->slice_strides()));
+ }
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
HloInstruction* dynamic_slice) {
auto operand = dynamic_slice->mutable_operand(0);
- auto start_indices = dynamic_slice->operand(1);
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
return ReplaceInstruction(dynamic_slice, operand);
}
- // DynamicSlice where operand has the same size as the output and
- // start_indices are all zero is simply equal to operand.
- if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) {
+ // DynamicSlice where operand has the same size as the output is simply equal
+ // to operand.
+ if (SameShape(operand, dynamic_slice)) {
return ReplaceInstruction(dynamic_slice, operand);
}
return Status::OK();
@@ -1750,20 +1783,10 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) {
auto update = dynamic_update_slice->mutable_operand(1);
- auto start_indices = dynamic_update_slice->operand(2);
- // DynamicUpdateSlice on a scalar just passes through the update argument.
- if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
- return ReplaceInstruction(dynamic_update_slice, update);
- }
- // DynamicUpdateSlice where operand and update have the same size and
- // start_indices are all zero is simply equal to update.
- //
- // (We require start_indices to be all zero because we want this optimization
- // not to affect the visible behavior of this op even when the indices are out
- // of range. Currently dynamic-update-slice wraps out-of-range indices, so
- // we can only remove the op if its indices never wrap.)
- if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) {
+ // DynamicUpdateSlice where operand and update have the same size is simply
+ // equal to update.
+ if (SameShape(dynamic_update_slice, update)) {
return ReplaceInstruction(dynamic_update_slice, update);
}
@@ -1889,6 +1912,26 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
}
+ // Convert Reduce(concat({a,b,...})) to
+ // map(reduce(a),map(reduce(b),...,))
+ //
+ // This should make fusion easier or use less memory bandwidth in the unfused
+ // case.
+ if (arg->opcode() == HloOpcode::kConcatenate &&
+ c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) {
+ HloInstruction* old_reduce = nullptr;
+ for (HloInstruction* operand : arg->operands()) {
+ HloInstruction* new_reduce = computation_->AddInstruction(
+ HloInstruction::CreateReduce(reduce->shape(), operand, init_value,
+ reduce->dimensions(), function));
+ if (old_reduce != nullptr) {
+ new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
+ reduce->shape(), {old_reduce, new_reduce}, function));
+ }
+ old_reduce = new_reduce;
+ }
+ return ReplaceInstruction(reduce, old_reduce);
+ }
return Status::OK();
}
@@ -2097,7 +2140,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::Zero(convolution->shape().element_type())
+ LiteralUtil::Zero(convolution->shape().element_type())
.CloneToUnique())),
{}));
}