aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc938
1 files changed, 938 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
new file mode 100644
index 0000000000..fe892e872f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -0,0 +1,938 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Returns whether operand is a literal with the given value.
+bool IsLiteralWithValue(const HloInstruction* operand, int value) {
+ return operand->opcode() == HloOpcode::kConstant &&
+ LiteralUtil::IsAll(operand->literal(), value);
+}
+
+// Returns whether the given transpose produces a result which is bit-wise
+// identical to its operand and thus may be replaced with a bitcast.
+bool TransposeIsBitcast(
+ const HloInstruction* transpose,
+ const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
+ CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
+ const HloInstruction* operand = transpose->operand(0);
+
+ // Can't insert bitcasts if the compiler used a memory layout which isn't
+ // compatible.
+ if (!valid_bitcast_callback(operand->shape(), transpose->shape())) {
+ return false;
+ }
+
+ return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
+ transpose->dimensions());
+}
+
+// Returns true if the given reshape produces a result which is bit-wise
+// identical to its operand and thus may be replaced with a bitcast.
+//
+// This function is conservative -- even if this function returns false, the
+// reshape may still be a bitcast. For example, a reshape from [28x28] to [784].
+bool ReshapeIsBitcast(
+ const HloInstruction* reshape,
+ const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
+ CHECK_EQ(HloOpcode::kReshape, reshape->opcode());
+
+ const HloInstruction* operand = reshape->operand(0);
+ // Can't insert bitcasts if the compiler used a memory layout which isn't
+ // compatible.
+ if (!valid_bitcast_callback(operand->shape(), reshape->shape())) {
+ return false;
+ }
+
+ return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape());
+}
+} // namespace
+
+// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
+// algebraic expressions to simplified forms. Note: This only supports
+// simplifications that simply look at the operands of an instruction. For the
+// more general case a worklist based approach would be needed.
+class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
+ public:
+ // Default visitor action is to do nothing and return OK.
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
+ return Status::OK();
+ }
+
+ Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+
+ Status HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) override;
+
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+
+ Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+
+ Status HandleLog(HloInstruction* log, HloInstruction* operand) override;
+
+ Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandlePad(HloInstruction* pad) override;
+
+ Status HandlePower(HloInstruction* power, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleReshape(HloInstruction* reshape) override;
+
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override;
+
+ Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+ Status HandleTranspose(HloInstruction* transpose) override;
+
+ Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ // Returns whether algebraic simplification has occurred.
+ const bool changed() const { return changed_; }
+
+ // Runs the visitor on a computation.
+ static bool Run(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback);
+
+ private:
+ explicit AlgebraicSimplifierVisitor(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback)
+ : computation_(computation),
+ is_layout_sensitive_(is_layout_sensitive),
+ valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+
+ // Convenience method for replacing an instruction with a bitcast.
+ void ReplaceWithBitcast(HloInstruction* instruction);
+
+ // Replace old instruction with new instruction if old and new instructions
+ // have the same shape. Updates uses and root instruction. Returns whether a
+ // replacement was made.
+ bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
+ HloInstruction* new_instruction);
+
+ // Returns whether the shape of the output of the given instructions are the
+ // same for the purposes of simplification. If is_layout_sensitive_ is true,
+ // then this tests shape equality including layout (ShapeUtil::Equal). If
+ // is_layout_sensitive_ is false, then the tests shape compatibility
+ // (ShapeUtil::Compatible).
+ bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;
+
+ // Returns whether it was possible to transform `root` to a clamp instruction.
+ // With min a minimum instruction, max a maximum instruction, min_operand a
+ // operand of min and max_operand a operand of max.
+ // Precondition: root is either a minimum or a maximum.
+ bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
+ HloInstruction* min_operand,
+ HloInstruction* operand, HloInstruction* max,
+ HloInstruction* max_operand);
+
+ // Current HloComputation instance the AlgebraicSimplifierVisitor is
+ // traversing.
+ HloComputation* computation_;
+
+ // Whether algebraic simplification has occurred.
+ bool changed_ = false;
+
+ // Whether layout is considered during transformation.
+ bool is_layout_sensitive_;
+
+ // Callback used to determine if a bitcast is valid.
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;
+};
+
+bool AlgebraicSimplifierVisitor::Run(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) {
+ AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive,
+ std::move(valid_bitcast_callback));
+ TF_CHECK_OK(computation->root_instruction()->Accept(&visitor));
+ return visitor.changed_;
+}
+
+bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
+ const HloInstruction* rhs) const {
+ if (is_layout_sensitive_) {
+ return ShapeUtil::Equal(lhs->shape(), rhs->shape());
+ } else {
+ return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
+ }
+}
+
+void AlgebraicSimplifierVisitor::ReplaceWithBitcast(
+ HloInstruction* instruction) {
+ CHECK_EQ(1, instruction->operand_count());
+ CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
+ ShapeUtil::ElementsIn(instruction->operand(0)->shape()));
+ CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
+ ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()));
+
+ auto bitcast = computation_->AddInstruction(
+ HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast,
+ instruction->mutable_operand(0)));
+ computation_->ReplaceInstruction(instruction, bitcast);
+ changed_ = true;
+}
+
+bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
+ HloInstruction* old_instruction, HloInstruction* new_instruction) {
+ if (!SameShape(old_instruction, new_instruction)) {
+ return false;
+ }
+ computation_->ReplaceInstruction(old_instruction, new_instruction);
+ changed_ = true;
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A + 0 => A
+ VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
+ if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
+ return Status::OK();
+ }
+ // 0 + A => A
+ VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
+ if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
+ HloInstruction* operand) {
+ // All copies can be eliminated (assuming layout constraints are satisified).
+ ReplaceInstructionIfSameShape(copy, operand);
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A - 0 => A
+ VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
+ if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A/1 => A
+ VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
+ if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
+ return Status::OK();
+ }
+
+ // exp(A)/exp(B) => exp(A-B)
+ if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
+ HloInstruction* subtract =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
+ rhs->mutable_operand(0)));
+ computation_->ReplaceWithNewInstruction(
+ divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
+ subtract));
+ changed_ = true;
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A*1 => A
+ VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
+ if (IsLiteralWithValue(rhs, 1) &&
+ ReplaceInstructionIfSameShape(multiply, lhs)) {
+ return Status::OK();
+ }
+ // 1*A => A
+ VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
+ if (IsLiteralWithValue(lhs, 1) &&
+ ReplaceInstructionIfSameShape(multiply, rhs)) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log,
+ HloInstruction* operand) {
+ // ln(exp(A)) => A
+ VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
+ if (operand->opcode() == HloOpcode::kExp &&
+ ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
+ HloInstruction* get_tuple_element, HloInstruction* operand) {
+ if (operand->opcode() == HloOpcode::kTuple) {
+ // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
+ VLOG(10) << "trying transform "
+ << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
+ << get_tuple_element->ToString();
+ if (ReplaceInstructionIfSameShape(
+ get_tuple_element,
+ operand->mutable_operand(get_tuple_element->tuple_index()))) {
+ return Status::OK();
+ }
+ }
+ return Status::OK();
+}
+
+namespace {
+
+// Return whether the given reshape instruction leaves the dimensions at the
+// given input indices unmodified, and returns their output indices.
+//
+// Example:
+// input_dim_indices = {2, 3}
+// input shape = T[a, b, x, y, cd]
+// output shape = T[ab, x, 1, y, c, d]
+// return value = {1, 3}
+//
+// Precondition: input_dim_indices is sorted.
// Return whether the given reshape instruction leaves the dimensions at the
// given input indices unmodified, and returns their output indices.
//
// Example:
//   input_dim_indices = {2, 3}
//   input shape  = T[a, b, x, y, cd]
//   output shape = T[ab, x, 1, y, c, d]
//   return value = {1, 3}
//
// The first element of the returned pair is false (with an empty vector) if
// any requested input dimension is modified by the reshape.
//
// Precondition: input_dim_indices is sorted.
std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
  CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
  CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));

  std::vector<int64> output_dim_indices;
  // Pairs of (input dimension, output dimension) the reshape leaves intact,
  // in increasing order of input dimension.
  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
                                               hlo->shape());
  size_t i = 0;  // index to unmodified_dims
  for (int64 input_dim_index : input_dim_indices) {
    // Search unmodified_dims for input_dim_index. We can search from the last
    // matching position because input_dim_indices is guaranteed to be sorted.
    while (i < unmodified_dims.size() &&
           unmodified_dims[i].first < input_dim_index) {
      ++i;
    }
    // If the requested input dimension is not among the unmodified ones, the
    // query fails as a whole.
    if (i >= unmodified_dims.size() ||
        unmodified_dims[i].first != input_dim_index) {
      return std::make_pair(false, std::vector<int64>());
    }
    output_dim_indices.push_back(unmodified_dims[i].second);
  }
  return std::make_pair(true, output_dim_indices);
}
+
+// Returns true if the output of "instruction" is a permutation of the elements
+// of "operand". Precondition: "operand" is an operand of "instruction".
+bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
+ HloInstruction* operand) {
+ DCHECK(!instruction->OperandIndices(operand).empty());
+ switch (instruction->opcode()) {
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSort:
+ case HloOpcode::kTranspose:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Returns true if the output of "instruction" is a subset of the elements of
+// "operand". Precondition: "operand" is an operand of "instruction".
+bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
+ HloInstruction* operand) {
+ std::vector<int64> operand_indices = instruction->OperandIndices(operand);
+ CHECK(!operand_indices.empty());
+ if (operand_indices.size() != 1) {
+ return false;
+ }
+ int64 operand_index = operand_indices[0];
+ switch (instruction->opcode()) {
+ case HloOpcode::kSlice:
+ CHECK_EQ(0, operand_index);
+ return true;
+ case HloOpcode::kDynamicSlice:
+ return operand_index == 0;
+ default:
+ return false;
+ }
+}
+
+} // namespace
+
// Simplifies broadcasts via three rewrites, attempted in order:
//   1. element-count-preserving broadcast -> reshape,
//   2. broadcast of a 1-sized-dimension-inserting reshape -> broadcast of the
//      reshape's operand,
//   3. scalar broadcast feeding a permuting/subsetting user -> broadcast
//      directly to the user's shape.
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  auto operand = broadcast->mutable_operand(0);
  // A degenerate broadcast of a reshape that does not change the number of
  // elements can be replaced by a reshape. The sortedness check ensures the
  // broadcast does not permute dimensions, so a reshape is equivalent.
  if (std::is_sorted(broadcast->dimensions().begin(),
                     broadcast->dimensions().end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    computation_->ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
    changed_ = true;
    return Status::OK();
  }

  // A broadcast of a reshape which merely inserts 1-sized dimensions can elide
  // its operand.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    // Only handle the pure-insertion case; deletions would require remapping
    // the broadcast dimensions differently.
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      // Erase from highest index to lowest so earlier erasures do not shift
      // the positions of later ones.
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      auto dims = broadcast->dimensions();
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      computation_->ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
      changed_ = true;
      return Status::OK();
    }
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use ReplaceUsesOfInstruction instead of ReplaceWithNewInstruction
        // because we are replacing an instruction other than the visited
        // instruction.
        computation_->ReplaceUsesOfInstruction(user, new_broadcast);
        changed_ = true;
        // NOTE(review): only one user is rewritten per visit; presumably the
        // pass relies on being re-run to catch remaining users — confirm.
        return Status::OK();
      }
    }
  }
  return Status::OK();
}
+
// Creates a constant instruction holding `src_literal` converted from
// primitive_src_type to primitive_dest_type. The source literal's element
// type must equal primitive_src_type (checked).
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<HloInstruction> ConvertIfTypesMatch(
    const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());

  // Map both PrimitiveType enum values to their native C++ types and perform
  // the element-wise conversion at the literal level.
  return HloInstruction::CreateConstant(
      LiteralUtil::Convert<typename primitive_util::PrimitiveTypeToNative<
                               primitive_src_type>::type,
                           typename primitive_util::PrimitiveTypeToNative<
                               primitive_dest_type>::type>(src_literal));
}
+
// Dispatches on the runtime destination type, instantiating
// ConvertIfTypesMatch for the matching compile-time destination type.
// LOG(FATAL)s on destination types that are not yet supported.
template <PrimitiveType primitive_src_type>
static std::unique_ptr<HloInstruction> ConvertIfDestTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (primitive_dest_type) {
// Expands to one case per supported destination type, forwarding to the
// fully-specialized converter.
#define CONVERT_IF_TYPES_MATCH(type) \
  case (type):                       \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}
+
// Dispatches on the runtime source element type of `src_literal`, then on the
// destination type (via ConvertIfDestTypeMatches), producing a new constant
// instruction with the converted literal. LOG(FATAL)s on unsupported source
// types.
static std::unique_ptr<HloInstruction> ConvertIfSrcTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (src_literal.shape().element_type()) {
// Expands to one case per supported source type, fixing the source type as a
// template argument and deferring destination dispatch.
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
  case (type):                             \
    return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}
+
+// A conversion to the same element type as the operand is a nop and can be
+// removed. A conversion of a constant can be simplified by making a new
+// constant.
// A conversion to the same element type as the operand is a nop and can be
// removed. A conversion of a constant can be simplified by making a new
// constant.
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
                                                 HloInstruction* operand) {
  const PrimitiveType src_type = operand->shape().element_type();
  const PrimitiveType dest_type = convert->shape().element_type();

  // Identity conversion: drop the convert entirely.
  if (src_type == dest_type) {
    computation_->ReplaceInstruction(convert, operand);
    changed_ = true;
    return Status::OK();
  }

  // Constant folding: convert the literal at compile time.
  if (operand->opcode() == HloOpcode::kConstant) {
    computation_->ReplaceWithNewInstruction(
        convert, ConvertIfSrcTypeMatches(operand->literal(), dest_type));
    changed_ = true;
  }
  return Status::OK();
}
+
+Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
+ // The pad instruction does nothing if the output shape is the same as the
+ // input shape, i.e, all paddings are zero.
+ ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
+ return Status::OK();
+}
+
// Simplifies pow() against small constant exponents: 0, 1, 2 and -1.
Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
                                               HloInstruction* lhs,
                                               HloInstruction* rhs) {
  // pow(A, 0) => 1 (broadcast to power's shape when non-scalar).
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 0)) {
    auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
        LiteralUtil::One(power->shape().element_type())));
    std::unique_ptr<HloInstruction> ones;
    if (ShapeUtil::IsScalar(power->shape())) {
      ones = std::move(one);
    } else {
      // Non-scalar result: broadcast the scalar one to the full shape.
      ones = HloInstruction::CreateBroadcast(
          power->shape(), computation_->AddInstruction(std::move(one)), {});
    }
    computation_->ReplaceWithNewInstruction(power, std::move(ones));
    changed_ = true;
    return Status::OK();
  }

  // pow(A, 1) => A.
  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
    return Status::OK();
  }

  // pow(A, 2) => A*A.
  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 2)) {
    computation_->ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
    changed_ = true;
    return Status::OK();
  }

  // pow(A, -1) => 1/A. The scalar one takes its element type from rhs, which
  // matches power's element type for a well-formed power instruction.
  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, -1)) {
    auto* one = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
            LiteralUtil::One(rhs->shape().element_type()))));
    computation_->ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            one, lhs));
    changed_ = true;
    return Status::OK();
  }
  return Status::OK();
}
+
// Simplifies reshapes: removes no-ops, collapses reshape-of-reshape, pushes
// reshapes into broadcasts, and converts reshapes into bitcasts when legal.
Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    computation_->ReplaceInstruction(reshape, operand);
    changed_ = true;
    return Status::OK();
  }

  // Merge reshapes: reshape(reshape(A)) => reshape(A).
  if (HloOpcode::kReshape == operand->opcode()) {
    computation_->ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
    changed_ = true;
    return Status::OK();
  }

  // reshape(broadcast(A)) => broadcast(A) when the reshape leaves all of the
  // broadcast's source dimensions unmodified; only their indices change.
  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.first) {
      computation_->ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              opt_dims.second));
      changed_ = true;
      return Status::OK();
    }
  }

  // Make this a bitcast if possible (only meaningful when layouts are fixed).
  if (is_layout_sensitive_ &&
      ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
    ReplaceWithBitcast(reshape);
    return Status::OK();
  }

  return Status::OK();
}
+
+Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
+ HloInstruction* operand) {
+ // Delete no-op slices, i.e. where shape = operand shape.
+ if (ReplaceInstructionIfSameShape(slice, operand)) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleReduce(
+ HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
+ if (ShapeUtil::ElementsIn(reduce->shape()) ==
+ ShapeUtil::ElementsIn(arg->shape())) {
+ auto reshape = computation_->AddInstruction(
+ HloInstruction::CreateReshape(reduce->shape(), arg));
+ computation_->ReplaceWithNewInstruction(
+ reduce, HloInstruction::CreateMap(reduce->shape(),
+ {reshape, init_value}, function));
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
+ auto operand = transpose->mutable_operand(0);
+
+ if (std::is_sorted(transpose->dimensions().begin(),
+ transpose->dimensions().end())) {
+ VLOG(10) << "deleting no-op transpose";
+ computation_->ReplaceInstruction(transpose, operand);
+ changed_ = true;
+ return Status::OK();
+ }
+
+ if (HloOpcode::kTranspose == operand->opcode()) {
+ computation_->ReplaceWithNewInstruction(
+ transpose, HloInstruction::CreateTranspose(
+ transpose->shape(), operand->mutable_operand(0),
+ ComposePermutations(operand->dimensions(),
+ transpose->dimensions())));
+ changed_ = true;
+ return Status::OK();
+ }
+
+ if (is_layout_sensitive_ &&
+ TransposeIsBitcast(transpose, valid_bitcast_callback_)) {
+ ReplaceWithBitcast(transpose);
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window) {
+ // HandleConvolution tries to replace a convolution with a DOT instruction.
+ //
+ // Only add when bitcasts can be used:
+ // - if bitcasts are not supported, then reshapes could be used but will
+ // end up with another copy.
+ // - if bitcasts are supported, the simplifier will be called again with
+ // bitcasts_ == true.
+
+ // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ if (!is_layout_sensitive_) return Status::OK();
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+ const Shape& input_shape = lhs->shape();
+ const Shape& filter_shape = rhs->shape();
+ const Shape& convolution_shape = convolution->shape();
+ TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
+ TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
+ TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
+
+ // Require 1x1 filter in the spatial dimensions (so no need to extract image
+ // patches).
+ if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(0)) != 1 ||
+ filter_shape.dimensions(dnums.kernel_spatial_dimensions(1)) != 1) {
+ return Status::OK();
+ }
+
+ // Stride ignores part of the output, which matrix multiplication does not do,
+ // so require no stride. Padding and base (lhs) dilation both implicitly
+ // extend the data, which matrix multiplication also does not do, so require
+ // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
+ // for a 1x1 window, so window dilation is no problem.
+ if (window_util::HasStride(window) || window_util::HasPadding(window) ||
+ window_util::HasBaseDilation(window)) {
+ return Status::OK();
+ }
+
+ // Also, the shapes must align for a rowmajor matmul:
+ // - the input and output have the same layout.
+ // - for input/output, the channel dimension must be the most minor. Other
+ // spatial dims can be in any order.
+ // - for filters, the input channel dimension must be more major than the
+ // output channel dimension. The width+height don't matter because
+ // they are 1.
+ //
+ // These constraints are harsh. If the channel dimension is the most major
+ // and/or the layout of input/output feature dimensions are reversed, we can
+ // still convert Conv into more efficient Matmul with operand transposition
+ // (such as the transposition flags in cuBLAS SGEMM).
+ if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
+ input_shape.layout().minor_to_major(0) != dnums.feature_dimension() ||
+ // The input feature dimension should come later in the minor-to-major
+ // order.
+ (PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+ dnums.kernel_input_feature_dimension()) <
+ PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+ dnums.kernel_output_feature_dimension()))) {
+ return Status::OK();
+ }
+
+ auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
+ std::vector<int64> dims(operand->shape().dimensions_size());
+ std::iota(dims.begin(), dims.end(), 0);
+ return computation_->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand));
+ };
+
+ // Replace it with a dot, with bitcasts around it to get the right shape.
+ const int64 input_channels =
+ input_shape.dimensions(dnums.feature_dimension());
+ const int64 output_channels =
+ filter_shape.dimensions(dnums.kernel_output_feature_dimension());
+
+ // Computes the product of the non-feature dimensions.
+ int64 conv_width = 1;
+ for (int i = 0; i < input_shape.dimensions_size(); ++i) {
+ if (i != dnums.feature_dimension()) {
+ conv_width *= input_shape.dimensions(i);
+ }
+ }
+
+ // We already checked feature_dimension is most minor, so data in input_shape
+ // and row-major {conv_width,input_channels} are bitwise identical.
+ const Shape new_input_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ input_shape.element_type(), {conv_width, input_channels});
+ // We already checked input_feature_dimension is more major than
+ // output_feature_dimension, so data in filter_shape and row-major
+ // {input_channels,output_channels} are bitwise identical.
+ const Shape new_filter_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ filter_shape.element_type(), {input_channels, output_channels});
+ const Shape dot_output_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ convolution_shape.element_type(), {conv_width, output_channels});
+
+ // We cannot insert bitcasts if the layouts will not be compatible.
+ // TODO(b/33178038): Consider inserting a transpose if a bitcast would be
+ // invalid.
+ if (!valid_bitcast_callback_(lhs->shape(), input_shape) ||
+ !valid_bitcast_callback_(rhs->shape(), new_filter_shape) ||
+ !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
+ return Status::OK();
+ }
+
+ auto new_lhs = add_bitcast(new_input_shape, lhs);
+ auto new_rhs = add_bitcast(new_filter_shape, rhs);
+ auto dot = computation_->AddInstruction(HloInstruction::CreateBinary(
+ dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs));
+ computation_->ReplaceInstruction(convolution,
+ add_bitcast(convolution_shape, dot));
+ changed_ = true;
+ return Status::OK();
+}
+
+bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
+ HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
+ HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
+ // Ensure shapes of min and max operand are equal to match current shape
+ // inference.
+ if (!SameShape(min_operand, max_operand)) {
+ return false;
+ }
+
+ auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
+ max_operand, operand, min_operand);
+ computation_->ReplaceWithNewInstruction(root, std::move(clamp));
+ changed_ = true;
+ return true;
+}
+
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
  // Match the following tree:
  //          min_operand     operand
  //                     \   /
  //      max_operand     min
  //                 \   /
  //                  max
  // where max_operand and min_operand are scalar constants.
  // On a match, rewrite max(max_operand, min(min_operand, operand)) into
  // clamp(max_operand, operand, min_operand) via TransformToClampIfSameShape.
  {
    HloInstruction* min;
    HloInstruction* max_operand;
    HloInstruction* min_operand;
    HloInstruction* operand;

    // Step 1: one operand of `maximum` must itself be a kMinimum.
    // Step 2: one operand of that minimum must be a scalar constant.
    // Step 3: the rewrite only fires if the bound shapes match.
    if (hlo_query::MatchBinaryInstructionOperandOpcode(
            HloOpcode::kMinimum, maximum,
            /*matching_operand=*/&min,
            /*other_operand=*/&max_operand) &&
        hlo_query::MatchBinaryInstructionOperand(
            hlo_query::IsScalarConstant, min,
            /*matching_operand=*/&min_operand,
            /*other_operand=*/&operand) &&
        TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
                                    max_operand)) {
      return Status::OK();
    }
  }

  return Status::OK();
}
+
Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
  // Match the following tree:
  //          max_operand     operand
  //                     \   /
  //      min_operand     max
  //                 \   /
  //                  min
  // where max_operand and min_operand are scalar constants.
  // On a match, rewrite min(min_operand, max(max_operand, operand)) into
  // clamp(max_operand, operand, min_operand) via TransformToClampIfSameShape.
  {
    HloInstruction* max;
    HloInstruction* max_operand;
    HloInstruction* min_operand;
    HloInstruction* operand;

    // Step 1: one operand of `minimum` must itself be a kMaximum.
    // Step 2: one operand of that maximum must be a scalar constant.
    // NOTE(review): `minimum` is passed for both the root and min parameters
    // below (the root is the min in this pattern) — presumably intentional,
    // since TransformToClampIfSameShape only reads root and the operands.
    if (hlo_query::MatchBinaryInstructionOperandOpcode(
            HloOpcode::kMaximum, minimum,
            /*matching_operand=*/&max,
            /*other_operand=*/&min_operand) &&
        hlo_query::MatchBinaryInstructionOperand(
            hlo_query::IsScalarConstant, max,
            /*matching_operand=*/&max_operand,
            /*other_operand=*/&operand) &&
        TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
                                    max_operand)) {
      return Status::OK();
    }
  }

  return Status::OK();
}
+
+StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
+ return std::any_of(
+ module->computations().begin(), module->computations().end(),
+ [=](const std::unique_ptr<HloComputation>& computation) {
+ return AlgebraicSimplifierVisitor::Run(
+ computation.get(), is_layout_sensitive_, valid_bitcast_callback_);
+ });
+}
+
+} // namespace xla