author    Peter Hawkins <phawkins@google.com>  2018-10-03 18:05:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-03 18:09:41 -0700
commit    18f589350f0cb244e2373480048d17cbacd241e1 (patch)
tree      3dc27e2ac9472201b86491cc08df5c03252984bb /tensorflow/compiler
parent    f7edc2d308523fa6c2d233c09e3f2da1c98e3dbc (diff)
[XLA] Add a size limit to the constant folder to avoid forming giant constants during compilation.
PiperOrigin-RevId: 215663002
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc       17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc  20
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index f837816cea..538816a353 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,6 +76,22 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
continue;
}
+ // Don't constant fold unless it's a net positive or the output is small.
+ int64 elements_in_removed_operands = 0;
+ for (HloInstruction* operand : instruction->operands()) {
+ if (operand->user_count() == 1) {
+ elements_in_removed_operands +=
+ ShapeUtil::ElementsIn(operand->shape());
+ }
+ }
+ int64 elements_in_constant = ShapeUtil::ElementsIn(instruction->shape());
+
+ static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
+ if (elements_in_constant > elements_in_removed_operands &&
+ elements_in_constant > kMaximumConstantSizeElements) {
+ continue;
+ }
+
Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
@@ -84,6 +100,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
<< instruction->ToString();
continue;
}
+ VLOG(4) << "Constant folded: " << instruction->ToString();
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(std::move(result))));
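To make the cost model in the hunk above easier to read in isolation, here is a minimal standalone sketch of the same decision, assuming plain int64_t element counts in place of XLA's HloInstruction and ShapeUtil types (ShouldFold and its parameter names are hypothetical, not part of the commit):

#include <cstdint>
#include <vector>

// Sketch of the folding heuristic added above. An operand with exactly one
// user disappears once the instruction is folded, so its elements count as a
// saving. Folding proceeds if it removes at least as many elements as it
// creates, or if the resulting constant stays below the 2M-element cap.
bool ShouldFold(int64_t elements_in_constant,
                const std::vector<int64_t>& sole_user_operand_elements) {
  constexpr int64_t kMaximumConstantSizeElements = 2 * 1000 * 1000;
  int64_t elements_in_removed_operands = 0;
  for (int64_t n : sole_user_operand_elements) {
    elements_in_removed_operands += n;
  }
  // This is the negation of the `continue` condition in the pass: skip only
  // when the fold is a net size increase AND the output exceeds the cap.
  return elements_in_constant <= elements_in_removed_operands ||
         elements_in_constant <= kMaximumConstantSizeElements;
}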
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 3e0def5d26..e45f905f71 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
+const char* const kConstantFoldLargePad = R"(
+ HloModule ConstantFoldLargePad
+
+ ENTRY r {
+ a = f32[1,1,1] constant(f32[1,1,1]{{{7}}})
+ b = f32[] constant(42)
+ ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
+ })";
+
+TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseAndReturnVerifiedModule(kConstantFoldLargePad));
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_FALSE(result);
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Pad(op::Constant(), op::Constant()));
+}
+
} // namespace
} // namespace xla
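For reference, the numbers in the test above land squarely on the new guard: the pad output has 2048 * 2048 * 128 = 536,870,912 elements, far above the 2,000,000-element cap, while its two sole-use constant operands contribute only 1 + 1 = 2 elements. Both halves of the skip condition hold, so (in terms of the hypothetical sketch earlier) ShouldFold returns false, the pass reports no change, and the root instruction remains op::Pad(op::Constant(), op::Constant()).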