diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_constant_folding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding.cc | 237 |
1 files changed, 32 insertions, 205 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index cb0a99d773..762ceebf39 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -24,230 +24,57 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace { - -template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type> -static std::unique_ptr<Literal> ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); -} - -template <PrimitiveType primitive_src_type> -static std::unique_ptr<Literal> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr<Literal> ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -} // namespace - -// ConstantFolderVisitor traverses the HLO computation and reduces certain -// constant graph sections, to literals. -class ConstantFolderVisitor : public DfsHloVisitorWithDefault { - public: - // Default visitor action is to do nothing and return OK. - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - - Status HandleReshape(HloInstruction* reshape) override; - - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - - Status HandleTranspose(HloInstruction* transpose) override; - - // Returns whether a constant folding operation has occurred. - const bool changed() const { return changed_; } - - // Runs the visitor on a computation and returns whether any changes were - // performed. - static StatusOr<bool> Run(HloComputation* computation); - - private: - ConstantFolderVisitor() = default; - - // Replaces the existing HLO instruction old_instruction, with a literal, - // and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceWithConstant(HloInstruction* old_instruction, - std::unique_ptr<Literal> literal) { - TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( - old_instruction, HloInstruction::CreateConstant(std::move(literal)))); - changed_ = true; - return Status::OK(); - } - - // Whether any constant folding operations have occurred. - bool changed_ = false; -}; - -StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) { - ConstantFolderVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - return visitor.changed(); -} StatusOr<bool> HloConstantFolding::Run(HloModule* module) { + auto evaluator = MakeUnique<HloEvaluator>(); + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& comp : module->computations()) { - TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get())); - changed = changed || result; - } - XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); - return changed; -} - -Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) { - if (reshape->operand(0)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( - auto reshaped_literal, - LiteralUtil::Reshape(reshape->operand(0)->literal(), - AsInt64Slice(reshape->shape().dimensions()))); - return ReplaceWithConstant(reshape, std::move(reshaped_literal)); - } - return Status::OK(); -} -Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) { - if (transpose->operand(0)->opcode() == HloOpcode::kConstant) { - auto transposed_literal = LiteralUtil::Transpose( - transpose->operand(0)->literal(), transpose->dimensions()); - return ReplaceWithConstant(transpose, std::move(transposed_literal)); - } - return Status::OK(); -} + for (auto& computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + // Skip dead code. + if (instruction->user_count() == 0 && + computation->root_instruction() != instruction) { + continue; + } + // Skip Constant and Parameter operation. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant) { + continue; + } + // Skip instructions with non-constant operands. + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + continue; + } -Status ConstantFolderVisitor::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice<HloInstruction*> operands) { - if (operands[0]->opcode() == HloOpcode::kConstant) { - // If all the operands of a concatenate are constant, fold them into a - // single constant tensor. - // The result concatenate dimension is going to be the sum of all the - // concatenate dimensions of the arrays taking part of the operation. - int64 concat_dim = concatenate->dimensions()[0]; - const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); - int64 rank = ShapeUtil::Rank(reference_shape); - std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(), - reference_shape.dimensions().end()); - if (concat_dim < 0) { - concat_dim += rank; - } - for (int64 i = 1; i < operands.size(); ++i) { - const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); - if (operands[i]->opcode() != HloOpcode::kConstant) { - return Status::OK(); + std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction); + // Currently we skip unimplemented operations. + // TODO(b/35975797): Fold constant computations for more operations. + if (result == nullptr) { + VLOG(2) << "Constant folding failed for instruction: " + << instruction->ToString(); + continue; } - // Accumulate the concat dimension from all tensors taking part to the - // operation. - concat_dimensions[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - auto literal = LiteralUtil::CreateFromDimensions( - reference_shape.element_type(), concat_dimensions); - std::vector<int64> source_indices(rank, 0); - std::vector<int64> dest_indices(concat_dimensions.size(), 0); - for (auto operand : operands) { - const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), source_indices, literal.get(), dest_indices, - AsInt64Slice(operand_shape.dimensions()))); - dest_indices[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateConstant(std::move(result)))); + changed = true; } - return ReplaceWithConstant(concatenate, std::move(literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - std::vector<int64> dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), slice->slice_starts(), literal.get(), dest_indices, - AsInt64Slice(shape.dimensions()))); - TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr<Literal> new_constant = - ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); - return ReplaceWithConstant(convert, std::move(new_constant)); - } - return Status::OK(); + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); + return changed; } } // namespace xla |