aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-03 17:08:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 17:14:05 -0700
commit467f195a2dd87257e3719576637774ebcf7a4590 (patch)
tree6d33032e6c4a628022dd00f2d2d56e5362771720 /tensorflow/tools/graph_transforms
parentdf2229540e9a1607193dcb8c83d5f3d7cf5d1a56 (diff)
Add max_constant_size_in_bytes parameter for ConstantFolding transform that sets the maximum size of each created constant.
PiperOrigin-RevId: 191523208
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc4
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc44
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 250f54e20f..85660f94a8 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -283,6 +283,10 @@ Status FoldConstants(const GraphDef& input_graph_def,
};
}
+ TF_RETURN_IF_ERROR(context.GetOneInt64Parameter(
+ "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes,
+ &cf_opts.max_constant_size_in_bytes));
+
// Constant folding.
bool was_mutated;
TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index 41106de008..6bfdfe43f5 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -370,6 +370,46 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("b"));
EXPECT_EQ(1, node_map.count("c"));
}
+
+ void TestMaxConstantSizeInBytes() {
+ auto root = tensorflow::Scope::NewRootScope();
+
+ const int width = 100;
+
+ Tensor a_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&a_data, 1.0f);
+ Output a_const = ::tensorflow::ops::Const(
+ root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
+
+ Tensor b_data(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&b_data, 1.0f);
+ Output b_const = ::tensorflow::ops::Const(
+ root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
+
+ Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"),
+ a_const, b_const);
+
+ Output placeholder = ::tensorflow::ops::Placeholder(
+ root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+
+ Output mul = ::tensorflow::ops::Mul(
+ root.WithOpName("output_expect_remains"), add, placeholder);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
+ test::FillIota<float>(&placeholder_tensor, 1.0f);
+
+ // Setting the maximum constant size to 10 bytes should stop the constant
+ // folding at add(a, b) that would have yielded a constant of
+ // 100*sizeof(float) bytes.
+ graph_transforms::TransformFuncContext context;
+ context.params["max_constant_size_in_bytes"] = {"10"};
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains", placeholder_tensor}},
+ {}, {"output_expect_remains"}, context);
+ }
};
TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); }
@@ -394,5 +434,9 @@ TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
TestRemoveUnusedNodesMultipleOutputs();
}
+TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
+ TestMaxConstantSizeInBytes();
+}
+
} // namespace graph_transforms
} // namespace tensorflow