diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-26 18:03:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 18:13:16 -0700 |
commit | 0c573ae93b3013b91d8d2493a1daed56d11ccc98 (patch) | |
tree | 5289b1b2960319c2c9e63c1ee522914c9038789d /tensorflow/core/common_runtime | |
parent | 5d61748f4e9998c9d2017bd01864b8fcb6d2127a (diff) |
Skip SymbolicGradientOp when doing constant folding in control flow functionalization.
If we want to evaluate SymbolicGradient op in constant folding, we need to construct Device object and attach it to FunctionLibraryRuntime. In graph rewriting pass, we do not have Device object created yet; it will only be created in XlaCompiler.
PiperOrigin-RevId: 214702943
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/graph_optimizer.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/graph_optimizer.h | 5 |
2 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 96ecfb41d4..37a979a8f1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -38,7 +38,8 @@ void GraphOptimizer::Optimize( std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn) { + const std::function<bool(const Node*)>& cse_consider_fn, + const std::function<bool(const Node*)>& cf_consider_fn) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -62,6 +63,7 @@ void GraphOptimizer::Optimize( if (opts_.do_constant_folding()) { ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; + cf_opts.consider = cf_consider_fn; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 80246281cd..789cc56942 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -45,12 +45,15 @@ class GraphOptimizer { // // If cse_consider_fn is not null then only nodes for which cse_consider_fn // returns true will be considered for CSE. + // If cf_consider_fn is not null then only nodes for which cf_consider_fn + // returns true will be considered for CF. void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn = nullptr); + const std::function<bool(const Node*)>& cse_consider_fn = nullptr, + const std::function<bool(const Node*)>& cf_consider_fn = nullptr); const OptimizerOptions& options() { return opts_; } |