diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 938 |
1 files changed, 938 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc new file mode 100644 index 0000000000..fe892e872f --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -0,0 +1,938 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include 
"tensorflow/core/platform/types.h"

namespace xla {
namespace {

// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int value) {
  return operand->opcode() == HloOpcode::kConstant &&
         LiteralUtil::IsAll(operand->literal(), value);
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
bool TransposeIsBitcast(
    const HloInstruction* transpose,
    const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);

  // Can't insert bitcasts if the compiler used a memory layout which isn't
  // compatible.
  if (!valid_bitcast_callback(operand->shape(), transpose->shape())) {
    return false;
  }

  // Layout-aware check: the transpose is a bitcast iff the physical element
  // order is unchanged by the permutation.
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Returns true if the given reshape produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
//
// This function is conservative -- even if this function returns false, the
// reshape may still be a bitcast. For example, a reshape from [28x28] to [784].
bool ReshapeIsBitcast(
    const HloInstruction* reshape,
    const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
  CHECK_EQ(HloOpcode::kReshape, reshape->opcode());

  const HloInstruction* operand = reshape->operand(0);
  // Can't insert bitcasts if the compiler used a memory layout which isn't
  // compatible.
  if (!valid_bitcast_callback(operand->shape(), reshape->shape())) {
    return false;
  }

  return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape());
}
}  // namespace

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms.
Note: This only supports +// simplifications that simply look at the operands of an instruction. For the +// more general case a worklist based approach would be needed. +class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override; + + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override; + + Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override; + + Status HandleLog(HloInstruction* log, HloInstruction* operand) override; + + Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandlePad(HloInstruction* pad) override; + + Status HandlePower(HloInstruction* power, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice<int64> dimensions, + HloComputation* function) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, + HloInstruction* rhs) override; + + Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, + HloInstruction* rhs) 
override; + + Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, + HloInstruction* rhs) override; + + // Returns whether algebraic simplification has occurred. + const bool changed() const { return changed_; } + + // Runs the visitor on a computation. + static bool Run( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback); + + private: + explicit AlgebraicSimplifierVisitor( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) + : computation_(computation), + is_layout_sensitive_(is_layout_sensitive), + valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + + // Convenience method for replacing an instruction with a bitcast. + void ReplaceWithBitcast(HloInstruction* instruction); + + // Replace old instruction with new instruction if old and new instructions + // have the same shape. Updates uses and root instruction. Returns whether a + // replacement was made. + bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction, + HloInstruction* new_instruction); + + // Returns whether the shape of the output of the given instructions are the + // same for the purposes of simplification. If is_layout_sensitive_ is true, + // then this tests shape equality including layout (ShapeUtil::Equal). If + // is_layout_sensitive_ is false, then the tests shape compatibility + // (ShapeUtil::Compatible). + bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; + + // Returns whether it was possible to transform `root` to a clamp instruction. + // With min a minimum instruction, max a maximum instruction, min_operand a + // operand of min and max_operand a operand of max. + // Precondition: root is either a minimum or a maximum. 
+ bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min, + HloInstruction* min_operand, + HloInstruction* operand, HloInstruction* max, + HloInstruction* max_operand); + + // Current HloComputation instance the AlgebraicSimplifierVisitor is + // traversing. + HloComputation* computation_; + + // Whether algebraic simplification has occurred. + bool changed_ = false; + + // Whether layout is considered during transformation. + bool is_layout_sensitive_; + + // Callback used to determine if a bitcast is valid. + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; +}; + +bool AlgebraicSimplifierVisitor::Run( + HloComputation* computation, bool is_layout_sensitive, + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) { + AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, + std::move(valid_bitcast_callback)); + TF_CHECK_OK(computation->root_instruction()->Accept(&visitor)); + return visitor.changed_; +} + +bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, + const HloInstruction* rhs) const { + if (is_layout_sensitive_) { + return ShapeUtil::Equal(lhs->shape(), rhs->shape()); + } else { + return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); + } +} + +void AlgebraicSimplifierVisitor::ReplaceWithBitcast( + HloInstruction* instruction) { + CHECK_EQ(1, instruction->operand_count()); + CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), + ShapeUtil::ElementsIn(instruction->operand(0)->shape())); + CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); + + auto bitcast = computation_->AddInstruction( + HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, + instruction->mutable_operand(0))); + computation_->ReplaceInstruction(instruction, bitcast); + changed_ = true; +} + +bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( + HloInstruction* old_instruction, HloInstruction* new_instruction) { + 
if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  computation_->ReplaceInstruction(old_instruction, new_instruction);
  changed_ = true;
  return true;
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
                                             HloInstruction* lhs,
                                             HloInstruction* rhs) {
  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
                                              HloInstruction* operand) {
  // All copies can be eliminated (assuming layout constraints are satisfied).
  ReplaceInstructionIfSameShape(copy, operand);
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub,
                                                  HloInstruction* lhs,
                                                  HloInstruction* rhs) {
  // A - 0 => A
  VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
  if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
                                                HloInstruction* lhs,
                                                HloInstruction* rhs) {
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsLiteralWithValue(rhs, 1) &&
      ReplaceInstructionIfSameShape(divide, lhs)) {
    return Status::OK();
  }

  // exp(A)/exp(B) => exp(A-B)
  if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    // Build the A-B subtraction first; it becomes the operand of the new exp.
    HloInstruction* subtract =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
            rhs->mutable_operand(0)));
computation_->ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
                                            subtract));
    changed_ = true;
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
                                                  HloInstruction* lhs,
                                                  HloInstruction* rhs) {
  // A*1 => A
  VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
  if (IsLiteralWithValue(rhs, 1) &&
      ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // 1*A => A
  VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
  if (IsLiteralWithValue(lhs, 1) &&
      ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log,
                                             HloInstruction* operand) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  if (operand->opcode() == HloOpcode::kExp &&
      ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
    HloInstruction* get_tuple_element, HloInstruction* operand) {
  if (operand->opcode() == HloOpcode::kTuple) {
    // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
    VLOG(10) << "trying transform "
             << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
             << get_tuple_element->ToString();
    if (ReplaceInstructionIfSameShape(
            get_tuple_element,
            operand->mutable_operand(get_tuple_element->tuple_index()))) {
      return Status::OK();
    }
  }
  return Status::OK();
}

namespace {

// Return whether the given reshape instruction leaves the dimensions at the
// given input indices unmodified, and returns their output indices.
//
// Example:
//   input_dim_indices = {2, 3}
//   input  shape = T[a, b, x, y, cd]
//   output shape = T[ab, x, 1, y, c, d]
//   return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
  CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
  CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));

  std::vector<int64> output_dim_indices;
  // Pairs of (input dimension, output dimension) that the reshape keeps
  // intact, in increasing order of input dimension.
  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
                                               hlo->shape());
  size_t i = 0;  // index to unmodified_dims
  for (int64 input_dim_index : input_dim_indices) {
    // Search unmodified_dims for input_dim_index. We can search from the last
    // matching position because input_dim_indices is guaranteed to be sorted.
    while (i < unmodified_dims.size() &&
           unmodified_dims[i].first < input_dim_index) {
      ++i;
    }
    if (i >= unmodified_dims.size() ||
        unmodified_dims[i].first != input_dim_index) {
      // A requested input dimension is modified by the reshape; give up.
      return std::make_pair(false, std::vector<int64>());
    }
    output_dim_indices.push_back(unmodified_dims[i].second);
  }
  return std::make_pair(true, output_dim_indices);
}

// Returns true if the output of "instruction" is a permutation of the elements
// of "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
                                          HloInstruction* operand) {
  DCHECK(!instruction->OperandIndices(operand).empty());
  switch (instruction->opcode()) {
    // These opcodes rearrange elements but never create or drop any.
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
      return true;
    default:
      return false;
  }
}

// Returns true if the output of "instruction" is a subset of the elements of
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
                                     HloInstruction* operand) {
  std::vector<int64> operand_indices = instruction->OperandIndices(operand);
  CHECK(!operand_indices.empty());
  if (operand_indices.size() != 1) {
    // "operand" appears more than once; the subset property is unclear.
    return false;
  }
  int64 operand_index = operand_indices[0];
  switch (instruction->opcode()) {
    case HloOpcode::kSlice:
      CHECK_EQ(0, operand_index);
      return true;
    case HloOpcode::kDynamicSlice:
      // Only the data operand (index 0) of a dynamic-slice is subsetted; the
      // other operand holds the start indices.
      return operand_index == 0;
    default:
      return false;
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  auto operand = broadcast->mutable_operand(0);
  // A degenerate broadcast of a reshape that does not change the number of
  // elements can be replaced by a reshape.
  if (std::is_sorted(broadcast->dimensions().begin(),
                     broadcast->dimensions().end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    computation_->ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
    changed_ = true;
    return Status::OK();
  }

  // A broadcast of a reshape which merely inserts 1-sized dimensions can elide
  // its operand.
{
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      // Erase from highest index first so earlier indices stay valid.
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      auto dims = broadcast->dimensions();
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      computation_->ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
      changed_ = true;
      return Status::OK();
    }
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use ReplaceUsesOfInstruction instead of ReplaceWithNewInstruction
        // because we are replacing an instruction other than the visited
        // instruction.
computation_->ReplaceUsesOfInstruction(user, new_broadcast);
        changed_ = true;
        // Return after the first replacement; iterating further would use an
        // invalidated users() list.
        return Status::OK();
      }
    }
  }
  return Status::OK();
}

// Creates a constant whose literal is "src_literal" converted from
// primitive_src_type to primitive_dest_type.
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<HloInstruction> ConvertIfTypesMatch(
    const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());

  return HloInstruction::CreateConstant(
      LiteralUtil::Convert<typename primitive_util::PrimitiveTypeToNative<
                               primitive_src_type>::type,
                           typename primitive_util::PrimitiveTypeToNative<
                               primitive_dest_type>::type>(src_literal));
}

// Dispatches on the runtime destination type to the statically-typed
// conversion above.
template <PrimitiveType primitive_src_type>
static std::unique_ptr<HloInstruction> ConvertIfDestTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
  case (type):                       \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
    // Other types are not yet supported.
+ default: + LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +static std::unique_ptr<HloInstruction> ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (src_literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +// A conversion to the same element type as the operand is a nop and can be +// removed. A conversion of a constant can be simplified by making a new +// constant. 
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
                                                 HloInstruction* operand) {
  PrimitiveType src_type = operand->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  // Same-type conversion is a no-op.
  if (src_type == dest_type) {
    computation_->ReplaceInstruction(convert, operand);
    changed_ = true;
    return Status::OK();
  }
  // Fold convert-of-constant into a new constant of the destination type.
  if (operand->opcode() == HloOpcode::kConstant) {
    const Literal& src_literal = operand->literal();
    std::unique_ptr<HloInstruction> new_constant =
        ConvertIfSrcTypeMatches(src_literal, dest_type);
    computation_->ReplaceWithNewInstruction(convert, std::move(new_constant));
    changed_ = true;
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  // The pad instruction does nothing if the output shape is the same as the
  // input shape, i.e, all paddings are zero.
  ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
                                               HloInstruction* lhs,
                                               HloInstruction* rhs) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 0)) {
    auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
        LiteralUtil::One(power->shape().element_type())));
    std::unique_ptr<HloInstruction> ones;
    if (ShapeUtil::IsScalar(power->shape())) {
      ones = std::move(one);
    } else {
      // Non-scalar result: broadcast the scalar one to the power's shape.
      ones = HloInstruction::CreateBroadcast(
          power->shape(), computation_->AddInstruction(std::move(one)), {});
    }
    computation_->ReplaceWithNewInstruction(power, std::move(ones));
    changed_ = true;
    return Status::OK();
  }

  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
    return Status::OK();
  }

  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, 2)) {
computation_->ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
    changed_ = true;
    return Status::OK();
  }

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsLiteralWithValue(rhs, -1)) {
    auto* one = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
            LiteralUtil::One(rhs->shape().element_type()))));
    computation_->ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            one, lhs));
    changed_ = true;
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    computation_->ReplaceInstruction(reshape, operand);
    changed_ = true;
    return Status::OK();
  }

  // Merge reshapes: reshape(reshape(X)) => reshape(X).
  if (HloOpcode::kReshape == operand->opcode()) {
    computation_->ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
    changed_ = true;
    return Status::OK();
  }

  // reshape(broadcast(X)) => broadcast(X) when the reshape leaves all
  // broadcast dimensions unmodified (just remapped to new output indices).
  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.first) {
      computation_->ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(),
              reshape->mutable_operand(0)->mutable_operand(0),
              opt_dims.second));
      changed_ = true;
      return Status::OK();
    }
  }

  // Make this a bitcast if possible.
+ if (is_layout_sensitive_ && + ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReplaceWithBitcast(reshape); + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + // Delete no-op slices, i.e. where shape = operand shape. + if (ReplaceInstructionIfSameShape(slice, operand)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleReduce( + HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) { + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + auto reshape = computation_->AddInstruction( + HloInstruction::CreateReshape(reduce->shape(), arg)); + computation_->ReplaceWithNewInstruction( + reduce, HloInstruction::CreateMap(reduce->shape(), + {reshape, init_value}, function)); + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { + auto operand = transpose->mutable_operand(0); + + if (std::is_sorted(transpose->dimensions().begin(), + transpose->dimensions().end())) { + VLOG(10) << "deleting no-op transpose"; + computation_->ReplaceInstruction(transpose, operand); + changed_ = true; + return Status::OK(); + } + + if (HloOpcode::kTranspose == operand->opcode()) { + computation_->ReplaceWithNewInstruction( + transpose, HloInstruction::CreateTranspose( + transpose->shape(), operand->mutable_operand(0), + ComposePermutations(operand->dimensions(), + transpose->dimensions()))); + changed_ = true; + return Status::OK(); + } + + if (is_layout_sensitive_ && + TransposeIsBitcast(transpose, valid_bitcast_callback_)) { + ReplaceWithBitcast(transpose); + return Status::OK(); + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution, HloInstruction* lhs, 
HloInstruction* rhs, + const Window& window) { + // HandleConvolution tries to replace a convolution with a DOT instruction. + // + // Only add when bitcasts can be used: + // - if bitcasts are not supported, then reshapes could be used but will + // end up with another copy. + // - if bitcasts are supported, the simplifier will be called again with + // bitcasts_ == true. + + // TODO(cwhipkey): b/31337498, make this layout insensitive. + if (!is_layout_sensitive_) return Status::OK(); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + const Shape& input_shape = lhs->shape(); + const Shape& filter_shape = rhs->shape(); + const Shape& convolution_shape = convolution->shape(); + TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape)); + + // Require 1x1 filter in the spatial dimensions (so no need to extract image + // patches). + if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(0)) != 1 || + filter_shape.dimensions(dnums.kernel_spatial_dimensions(1)) != 1) { + return Status::OK(); + } + + // Stride ignores part of the output, which matrix multiplication does not do, + // so require no stride. Padding and base (lhs) dilation both implicitly + // extend the data, which matrix multiplication also does not do, so require + // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect + // for a 1x1 window, so window dilation is no problem. + if (window_util::HasStride(window) || window_util::HasPadding(window) || + window_util::HasBaseDilation(window)) { + return Status::OK(); + } + + // Also, the shapes must align for a rowmajor matmul: + // - the input and output have the same layout. + // - for input/output, the channel dimension must be the most minor. Other + // spatial dims can be in any order. 
+ // - for filters, the input channel dimension must be more major than the + // output channel dimension. The width+height don't matter because + // they are 1. + // + // These constraints are harsh. If the channel dimension is the most major + // and/or the layout of input/output feature dimensions are reversed, we can + // still convert Conv into more efficient Matmul with operand transposition + // (such as the transposition flags in cuBLAS SGEMM). + if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || + input_shape.layout().minor_to_major(0) != dnums.feature_dimension() || + // The input feature dimension should come later in the minor-to-major + // order. + (PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()), + dnums.kernel_input_feature_dimension()) < + PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()), + dnums.kernel_output_feature_dimension()))) { + return Status::OK(); + } + + auto add_bitcast = [&](Shape shape, HloInstruction* operand) { + std::vector<int64> dims(operand->shape().dimensions_size()); + std::iota(dims.begin(), dims.end(), 0); + return computation_->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand)); + }; + + // Replace it with a dot, with bitcasts around it to get the right shape. + const int64 input_channels = + input_shape.dimensions(dnums.feature_dimension()); + const int64 output_channels = + filter_shape.dimensions(dnums.kernel_output_feature_dimension()); + + // Computes the product of the non-feature dimensions. + int64 conv_width = 1; + for (int i = 0; i < input_shape.dimensions_size(); ++i) { + if (i != dnums.feature_dimension()) { + conv_width *= input_shape.dimensions(i); + } + } + + // We already checked feature_dimension is most minor, so data in input_shape + // and row-major {conv_width,input_channels} are bitwise identical. 
+ const Shape new_input_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), {conv_width, input_channels}); + // We already checked input_feature_dimension is more major than + // output_feature_dimension, so data in filter_shape and row-major + // {input_channels,output_channels} are bitwise identical. + const Shape new_filter_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + filter_shape.element_type(), {input_channels, output_channels}); + const Shape dot_output_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + convolution_shape.element_type(), {conv_width, output_channels}); + + // We cannot insert bitcasts if the layouts will not be compatible. + // TODO(b/33178038): Consider inserting a transpose if a bitcast would be + // invalid. + if (!valid_bitcast_callback_(lhs->shape(), input_shape) || + !valid_bitcast_callback_(rhs->shape(), new_filter_shape) || + !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + return Status::OK(); + } + + auto new_lhs = add_bitcast(new_input_shape, lhs); + auto new_rhs = add_bitcast(new_filter_shape, rhs); + auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + computation_->ReplaceInstruction(convolution, + add_bitcast(convolution_shape, dot)); + changed_ = true; + return Status::OK(); +} + +bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( + HloInstruction* root, HloInstruction* min, HloInstruction* min_operand, + HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) { + // Ensure shapes of min and max operand are equal to match current shape + // inference. 
if (!SameShape(min_operand, max_operand)) {
    return false;
  }

  // clamp(low, value, high) with low = max_operand, high = min_operand.
  auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
                                             max_operand, operand, min_operand);
  computation_->ReplaceWithNewInstruction(root, std::move(clamp));
  changed_ = true;
  return true;
}

Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
  // Match the following tree:
  //          min_operand     operand
  //                     \   /
  //      max_operand     min
  //                 \   /
  //                  max
  // where max_operand and min_operand are scalar constants.
  {
    HloInstruction* min;
    HloInstruction* max_operand;
    HloInstruction* min_operand;
    HloInstruction* operand;

    if (hlo_query::MatchBinaryInstructionOperandOpcode(
            HloOpcode::kMinimum, maximum,
            /*matching_operand=*/&min,
            /*other_operand=*/&max_operand) &&
        hlo_query::MatchBinaryInstructionOperand(
            hlo_query::IsScalarConstant, min,
            /*matching_operand=*/&min_operand,
            /*other_operand=*/&operand) &&
        TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
                                    max_operand)) {
      return Status::OK();
    }
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
  // Match the following tree:
  //          max_operand     operand
  //                     \   /
  //      min_operand     max
  //                 \   /
  //                  min
  // where max_operand and min_operand are scalar constants.
+ { + HloInstruction* max; + HloInstruction* max_operand; + HloInstruction* min_operand; + HloInstruction* operand; + + if (hlo_query::MatchBinaryInstructionOperandOpcode( + HloOpcode::kMaximum, minimum, + /*matching_operand=*/&max, + /*other_operand=*/&min_operand) && + hlo_query::MatchBinaryInstructionOperand( + hlo_query::IsScalarConstant, max, + /*matching_operand=*/&max_operand, + /*other_operand=*/&operand) && + TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, + max_operand)) { + return Status::OK(); + } + } + + return Status::OK(); +} + +StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { + return std::any_of( + module->computations().begin(), module->computations().end(), + [=](const std::unique_ptr<HloComputation>& computation) { + return AlgebraicSimplifierVisitor::Run( + computation.get(), is_layout_sensitive_, valid_bitcast_callback_); + }); +} + +} // namespace xla |