diff options
author:    Tong Shen <endlessroad@google.com>                 2018-10-04 11:24:41 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>   2018-10-04 11:28:53 -0700
commit:    c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (patch)
tree:      e9e2ee47e91ecb831faa7666e2213967b650c921 /tensorflow/core/common_runtime
parent:    6850dafeeaaa48efa748134688844bd079ef3949 (diff)
Roll forward change "Skip control flow functionalization if there is no Switch or Merge node.".
PiperOrigin-RevId: 215772272
Diffstat (limited to 'tensorflow/core/common_runtime')
4 files changed, 32 insertions, 19 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 419867ff58..db137f1a19 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,7 +466,7 @@ Graph* GetConstantGraph(
 bool ReplaceTensorWithConstant(
     Graph* graph, Device* partition_device, NodeAndOutput tensor,
     const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
-    int64 max_constant_size_in_bytes,
+    int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
     const ConstantFoldNameGenerator& generate_new_name) {
   // Be conservative when replacing a tensor with a constant, when not
   // running on CPU.
@@ -535,21 +535,23 @@ bool ReplaceTensorWithConstant(
   if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
     return false;
   }
-  if (partition_device && device_type != DEVICE_CPU) {
-    MemoryType original_output_memory_type;
-    if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
-                             &original_output_memory_type)
-             .ok()) {
-      return false;
-    }
-    MemoryType const_output_memory_type;
-    if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
-                             &const_output_memory_type)
-             .ok()) {
-      return false;
-    }
-    if (original_output_memory_type != const_output_memory_type) {
-      return false;
+  if (!disable_memory_output_type_check) {
+    if (partition_device && device_type != DEVICE_CPU) {
+      MemoryType original_output_memory_type;
+      if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+                               &original_output_memory_type)
+               .ok()) {
+        return false;
+      }
+      MemoryType const_output_memory_type;
+      if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+                               &const_output_memory_type)
+               .ok()) {
+        return false;
+      }
+      if (original_output_memory_type != const_output_memory_type) {
+        return false;
+      }
     }
   }
   for (auto edge : edges_to_remove) {
@@ -658,7 +660,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
         constant_control_deps[tensors_to_replace[c].first];
     if (ReplaceTensorWithConstant(
             graph, partition_device, tensors_to_replace[c], outputs[c],
-            control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
+            control_deps, opts.max_constant_size_in_bytes,
+            opts.disable_memory_output_type_check, generate_new_name)) {
       ++num_nodes_replaced;
     }
   }
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index a9a84f761b..4c71b7bd27 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -45,6 +45,10 @@ struct ConstantFoldingOptions {
   // optimization.
   int64 max_constant_size_in_bytes = 10 * 1024 * 1024;

+  // If disable_memory_output_type_check is true, we will disable output memory
+  // type check for constant node replacement.
+  bool disable_memory_output_type_check = false;
+
   // A generator for the name suffix of constant folded nodes. A
   // default id generator that monotonically increases is used if nullptr is
   // passed.
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 37a979a8f1..91194bc86f 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -39,7 +39,8 @@ void GraphOptimizer::Optimize(
     const std::unordered_map<string, std::vector<PartialTensorShape>>*
         shape_map,
     const std::function<bool(const Node*)>& cse_consider_fn,
-    const std::function<bool(const Node*)>& cf_consider_fn) {
+    const std::function<bool(const Node*)>& cf_consider_fn,
+    bool cf_disable_memory_output_type_check) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);

@@ -64,6 +65,8 @@ void GraphOptimizer::Optimize(
   ConstantFoldingOptions cf_opts;
   cf_opts.shape_map = shape_map;
   cf_opts.consider = cf_consider_fn;
+  cf_opts.disable_memory_output_type_check =
+      cf_disable_memory_output_type_check;
   if (opts_.max_folded_constant_in_bytes() > 0) {
     cf_opts.max_constant_size_in_bytes =
         opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 789cc56942..8954e9612d 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -47,13 +47,16 @@ class GraphOptimizer {
   // returns true will be considered for CSE.
   // If cf_consider_fn is not null then only nodes for which cf_consider_fn
   // returns true will be considered for CF.
+  // If cf_disable_memory_output_type_check is true, CF will discard output
+  // memory type check for constant node replacement.
   void Optimize(
       FunctionLibraryRuntime* runtime, Env* env, Device* device,
       std::unique_ptr<Graph>* graph,
       const std::unordered_map<string, std::vector<PartialTensorShape>>*
           shape_map,
       const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
-      const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
+      const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
+      bool cf_disable_memory_output_type_check = false);

   const OptimizerOptions& options() { return opts_; }