author     Tong Shen <endlessroad@google.com>                2018-10-04 11:24:41 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-04 11:28:53 -0700
commit     c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (patch)
tree       e9e2ee47e91ecb831faa7666e2213967b650c921 /tensorflow
parent     6850dafeeaaa48efa748134688844bd079ef3949 (diff)
Roll forward change "Skip control flow functionalization if there is no Switch or Merge node.".
PiperOrigin-RevId: 215772272
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc  129
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.cc          37
-rw-r--r--  tensorflow/core/common_runtime/constant_folding.h            4
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.cc            5
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.h             5
5 files changed, 122 insertions, 58 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 36c6f5d316..28e09d7b79 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -79,7 +79,10 @@ Status FunctionalizeControlFlowForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
- std::map<string, string>* canonicalized_name_to_new_name) {
+ std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
+ bool* modified) {
+ *modified = false;
+
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
@@ -92,6 +95,19 @@ Status FunctionalizeControlFlowForFunction(
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ // Check if the graph has a Switch or Merge node before optimizing the graph.
+ bool has_switch_or_merge = false;
+ for (Node* n : body->graph->nodes()) {
+ if (n->type_string() == "Switch" || n->type_string() == "Merge") {
+ has_switch_or_merge = true;
+ break;
+ }
+ }
+ // We cannot return here directly if the graph has no Switch/Merge.
+ // It might contain function call nodes, or If/While nodes with Switch/Merge
+ // in function body. We still need to rewrite those functions and modify
+ // corresponding nodes.
+
// Call graph optimizer. The most important optimization we need is constant
// folding, which will replace ops like Shape/BroadcastGradientArgs with
// constant shape input. Without this optimization, those ops might become
@@ -129,6 +145,13 @@ Status FunctionalizeControlFlowForFunction(
absl::StrCat("functionalize_control_flow_after_opt_", func_name),
*optimized_graph, fld);
}
+ // Some inlined functions might have Switch/Merge nodes.
+ for (Node* n : optimized_graph->nodes()) {
+ if (n->type_string() == "Switch" || n->type_string() == "Merge") {
+ has_switch_or_merge = true;
+ break;
+ }
+ }
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -151,10 +174,15 @@ Status FunctionalizeControlFlowForFunction(
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
+ bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
- // If we already functionalized this function, skip functionalization
- // but still rewrite the node.
- new_name = iter->second;
+ // If we already processed this function, check if it was rewritten. If
+ // the function was rewritten, the entry will be non-empty. Otherwise
+ // the entry will be empty.
+ function_modified = iter->second.has_value();
+ if (function_modified) {
+ new_name = iter->second.value();
+ }
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
@@ -166,42 +194,62 @@ Status FunctionalizeControlFlowForFunction(
}
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
- canonicalized_name_to_new_name));
- (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ canonicalized_name_to_new_name, &function_modified));
+ if (function_modified) {
+ // If the function was rewritten, add a non-empty entry. So later we
+ // know we have processed this function, and it was rewritten into
+ // another function.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ } else {
+ // If the function was not rewritten, add an empty entry. So later
+ // we know we have processed this function, and it does not need to be
+ // rewritten.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
+ }
+ }
+ if (function_modified) {
+ *modified = true;
+
+ // Notice that if "n" is a function call, RewriteAssociatedFunction()
+ // will delete it and create a new node instead, making "n" an invalid
+ // pointer. That's fine because in that case, associated_functions will
+ // only have one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
- // Notice that if "n" is a function call, RewriteAssociatedFunction() will
- // delete it and create a new node instead, making "n" an invalid pointer.
- // That's fine because in that case, associated_functions will only have
- // one member and the loop will only run once.
- TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- // Functionalize the function body.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *optimized_graph, fld);
- }
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *optimized_graph, fld);
+ if (has_switch_or_merge) {
+ *modified = true;
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *optimized_graph, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+ *optimized_graph, fld);
+ }
}
- FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
- &functionalized_fdef));
- // Add rewritten FunctionDef into library.
- if (func_name == new_func_name) {
- VLOG(2) << "Replacing function " << func_name;
- TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(new_func_name, functionalized_fdef));
- } else {
- VLOG(2) << "Adding function " << new_func_name;
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ if (*modified) {
+ // Add rewritten FunctionDef into library.
+ FunctionDef functionalized_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
}
return ret_status;
@@ -227,7 +275,7 @@ Status FunctionalizeControlFlowPass::Run(
{"TPUCompile", "function"},
{"XlaLaunch", "function"},
};
- std::map<string, string> canonicalized_name_to_new_name;
+ std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@@ -242,12 +290,15 @@ Status FunctionalizeControlFlowPass::Run(
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_f15n_"));
+ bool modified;
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
- &canonicalized_name_to_new_name));
- n->ClearAttr(func_attr);
- func.set_name(new_func_name);
- n->AddAttr(func_attr, func);
+ &canonicalized_name_to_new_name, &modified));
+ if (modified) {
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
}
}
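
The heart of this file's change is the early Switch/Merge scan plus the optional-valued memoization map that records whether an associated function was actually rewritten. The following standalone sketch (not from the commit) illustrates both idioms, with std::optional and plain op-type strings standing in for absl::optional and TensorFlow Node objects:

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph: just a list of op-type strings.
using FakeGraph = std::vector<std::string>;

// Mirrors the scan added above: functionalization is only needed when the
// graph (or an inlined function body) contains a Switch or Merge op.
bool HasSwitchOrMerge(const FakeGraph& graph) {
  for (const std::string& op : graph) {
    if (op == "Switch" || op == "Merge") return true;
  }
  return false;
}

int main() {
  std::cout << HasSwitchOrMerge({"Const", "Switch", "Merge"}) << "\n";  // 1
  std::cout << HasSwitchOrMerge({"Const", "MatMul"}) << "\n";           // 0

  // Mirrors canonicalized_name_to_new_name: a present value means "processed
  // and rewritten to this name"; an empty optional means "processed, but no
  // rewrite was needed", so callers can skip updating the calling node.
  std::map<std::string, std::optional<std::string>> processed;
  processed["f_canonical"] = "f_f15n_0";    // rewritten
  processed["g_canonical"] = std::nullopt;  // unchanged
  for (const auto& [name, new_name] : processed) {
    std::cout << name << " -> " << new_name.value_or("<unchanged>") << "\n";
  }
  return 0;
}
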
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 419867ff58..db137f1a19 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,7 +466,7 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
- int64 max_constant_size_in_bytes,
+ int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
@@ -535,21 +535,23 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
+ if (!disable_memory_output_type_check) {
+ if (partition_device && device_type != DEVICE_CPU) {
+ MemoryType original_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+ &original_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ MemoryType const_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+ &const_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ if (original_output_memory_type != const_output_memory_type) {
+ return false;
+ }
}
}
for (auto edge : edges_to_remove) {
@@ -658,7 +660,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
- control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
+ control_deps, opts.max_constant_size_in_bytes,
+ opts.disable_memory_output_type_check, generate_new_name)) {
++num_nodes_replaced;
}
}
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index a9a84f761b..4c71b7bd27 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -45,6 +45,10 @@ struct ConstantFoldingOptions {
// optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
+ // If disable_memory_output_type_check is true, we will disable output memory
+ // type check for constant node replacement.
+ bool disable_memory_output_type_check = false;
+
// A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is
// passed.
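
A minimal sketch (not part of the commit) of how a caller could populate the extended ConstantFoldingOptions; only fields visible in this header are set, and the snippet assumes it is compiled inside a TensorFlow source tree:

#include "tensorflow/core/common_runtime/constant_folding.h"

namespace example {

// Options suitable for a pre-functionalization constant-folding pass:
// keep the default size cap, but opt out of the output memory-type check
// via the new field added in this change.
tensorflow::ConstantFoldingOptions MakeFoldingOptions() {
  tensorflow::ConstantFoldingOptions opts;
  opts.max_constant_size_in_bytes = 10 * 1024 * 1024;
  opts.disable_memory_output_type_check = true;
  return opts;
}

}  // namespace example
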
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 37a979a8f1..91194bc86f 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -39,7 +39,8 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn,
- const std::function<bool(const Node*)>& cf_consider_fn) {
+ const std::function<bool(const Node*)>& cf_consider_fn,
+ bool cf_disable_memory_output_type_check) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -64,6 +65,8 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn;
+ cf_opts.disable_memory_output_type_check =
+ cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 789cc56942..8954e9612d 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -47,13 +47,16 @@ class GraphOptimizer {
// returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF.
+ // If cf_disable_memory_output_type_check is true, CF will skip the output
+ // memory type check for constant node replacement.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
- const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
+ bool cf_disable_memory_output_type_check = false);
const OptimizerOptions& options() { return opts_; }
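
Given the updated declaration above, a call that opts out of the memory-type check during folding might look like the sketch below; the surrounding objects (runtime, env, device, graph) are assumed to be supplied by the caller, as in FunctionalizeControlFlowForFunction:

#include <memory>

#include "tensorflow/core/common_runtime/graph_optimizer.h"

namespace example {

// Sketch only: runs the optimizer with the CSE/CF consider functions left at
// their defaults and the new cf_disable_memory_output_type_check flag enabled.
void OptimizeBeforeFunctionalization(tensorflow::GraphOptimizer* optimizer,
                                     tensorflow::FunctionLibraryRuntime* runtime,
                                     tensorflow::Env* env,
                                     tensorflow::Device* device,
                                     std::unique_ptr<tensorflow::Graph>* graph) {
  optimizer->Optimize(runtime, env, device, graph,
                      /*shape_map=*/nullptr,
                      /*cse_consider_fn=*/nullptr,
                      /*cf_consider_fn=*/nullptr,
                      /*cf_disable_memory_output_type_check=*/true);
}

}  // namespace example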