diff options
author | 2018-04-03 17:08:57 -0700 | |
---|---|---|
committer | 2018-04-03 17:14:05 -0700 | |
commit | 467f195a2dd87257e3719576637774ebcf7a4590 (patch) | |
tree | 6d33032e6c4a628022dd00f2d2d56e5362771720 /tensorflow/tools/graph_transforms | |
parent | df2229540e9a1607193dcb8c83d5f3d7cf5d1a56 (diff) |
Add max_constant_size_in_bytes parameter for ConstantFolding transform that sets the maximum size of each created constant.
PiperOrigin-RevId: 191523208
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_lib.cc | 4 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_test.cc | 44 |
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 250f54e20f..85660f94a8 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -283,6 +283,10 @@ Status FoldConstants(const GraphDef& input_graph_def, }; } + TF_RETURN_IF_ERROR(context.GetOneInt64Parameter( + "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes, + &cf_opts.max_constant_size_in_bytes)); + // Constant folding. bool was_mutated; TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr, diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 41106de008..6bfdfe43f5 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -370,6 +370,46 @@ class ConstantFoldingTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("b")); EXPECT_EQ(1, node_map.count("c")); } + + void TestMaxConstantSizeInBytes() { + auto root = tensorflow::Scope::NewRootScope(); + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&a_data, 1.0f); + Output a_const = ::tensorflow::ops::Const( + root.WithOpName("a_expect_remains"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&b_data, 1.0f); + Output b_const = ::tensorflow::ops::Const( + root.WithOpName("b_expect_remains"), Input::Initializer(b_data)); + + Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"), + a_const, b_const); + + Output placeholder = ::tensorflow::ops::Placeholder( + root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + + Output mul = ::tensorflow::ops::Mul( + root.WithOpName("output_expect_remains"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&placeholder_tensor, 1.0f); + + // Setting the maximum constant size to 10 bytes should stop the constant + // folding at add(a, b) that would have yielded a constant of + // 100*sizeof(float) bytes. + graph_transforms::TransformFuncContext context; + context.params["max_constant_size_in_bytes"] = {"10"}; + TestConstantFolding(graph_def, + {{"placeholder_expect_remains", placeholder_tensor}}, + {}, {"output_expect_remains"}, context); + } }; TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); } @@ -394,5 +434,9 @@ TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { TestRemoveUnusedNodesMultipleOutputs(); } +TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) { + TestMaxConstantSizeInBytes(); +} + } // namespace graph_transforms } // namespace tensorflow |