aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-26 18:03:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 18:13:16 -0700
commit0c573ae93b3013b91d8d2493a1daed56d11ccc98 (patch)
tree5289b1b2960319c2c9e63c1ee522914c9038789d /tensorflow/core/common_runtime
parent5d61748f4e9998c9d2017bd01864b8fcb6d2127a (diff)
Skip SymbolicGradientOp when doing constant folding in control flow functionalization.
If we want to evaluate SymbolicGradient op in constant folding, we need to construct Device object and attach it to FunctionLibraryRuntime. In graph rewriting pass, we do not have Device object created yet; it will only be created in XlaCompiler. PiperOrigin-RevId: 214702943
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h5
2 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 80246281cd..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -45,12 +45,15 @@ class GraphOptimizer {
//
// If cse_consider_fn is not null then only nodes for which cse_consider_fn
// returns true will be considered for CSE.
+ // If cf_consider_fn is not null then only nodes for which cf_consider_fn
+ // returns true will be considered for CF.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }