about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core
diff options
context:
space:
mode:
author: Tong Shen <endlessroad@google.com> 2018-10-04 11:24:41 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-04 11:28:53 -0700
commit: c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (patch)
tree: e9e2ee47e91ecb831faa7666e2213967b650c921 /tensorflow/core
parent: 6850dafeeaaa48efa748134688844bd079ef3949 (diff)
Roll forward change "Skip control flow functionalization if there is no Switch or Merge node.".
PiperOrigin-RevId: 215772272
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/common_runtime/constant_folding.cc 37
-rw-r--r-- tensorflow/core/common_runtime/constant_folding.h 4
-rw-r--r-- tensorflow/core/common_runtime/graph_optimizer.cc 5
-rw-r--r-- tensorflow/core/common_runtime/graph_optimizer.h 5
4 files changed, 32 insertions, 19 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 419867ff58..db137f1a19 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,7 +466,7 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
- int64 max_constant_size_in_bytes,
+ int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
@@ -535,21 +535,23 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
+ if (!disable_memory_output_type_check) {
+ if (partition_device && device_type != DEVICE_CPU) {
+ MemoryType original_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+ &original_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ MemoryType const_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+ &const_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ if (original_output_memory_type != const_output_memory_type) {
+ return false;
+ }
}
}
for (auto edge : edges_to_remove) {
@@ -658,7 +660,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
- control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
+ control_deps, opts.max_constant_size_in_bytes,
+ opts.disable_memory_output_type_check, generate_new_name)) {
++num_nodes_replaced;
}
}
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index a9a84f761b..4c71b7bd27 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -45,6 +45,10 @@ struct ConstantFoldingOptions {
// optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
+ // If disable_memory_output_type_check is true, the output memory type
+ // check is skipped when replacing a tensor with a constant node.
+ bool disable_memory_output_type_check = false;
+
// A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is
// passed.
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 37a979a8f1..91194bc86f 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -39,7 +39,8 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn,
- const std::function<bool(const Node*)>& cf_consider_fn) {
+ const std::function<bool(const Node*)>& cf_consider_fn,
+ bool cf_disable_memory_output_type_check) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -64,6 +65,8 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn;
+ cf_opts.disable_memory_output_type_check =
+ cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 789cc56942..8954e9612d 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -47,13 +47,16 @@ class GraphOptimizer {
// returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF.
+ // If cf_disable_memory_output_type_check is true, CF will skip the output
+ // memory type check for constant node replacement.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
- const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
+ bool cf_disable_memory_output_type_check = false);
const OptimizerOptions& options() { return opts_; }