diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 20 |
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index f837816cea..538816a353 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,6 +76,22 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { continue; } + // Don't constant fold unless it's a net positive or the output is small. + int64 elements_in_removed_operands = 0; + for (HloInstruction* operand : instruction->operands()) { + if (operand->user_count() == 1) { + elements_in_removed_operands += + ShapeUtil::ElementsIn(operand->shape()); + } + } + int64 elements_in_constant = ShapeUtil::ElementsIn(instruction->shape()); + + static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + if (elements_in_constant > elements_in_removed_operands && + elements_in_constant > kMaximumConstantSizeElements) { + continue; + } + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. @@ -84,6 +100,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { << instruction->ToString(); continue; } + VLOG(4) << "Constant folded: " << instruction->ToString(); TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(std::move(result)))); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 3e0def5d26..e45f905f71 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } +const char* const kConstantFoldLargePad = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + b = f32[] constant(42) + ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 + })"; + +TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kConstantFoldLargePad)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla |