aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/graph_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/graph_optimizer.cc')
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();