aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-12 10:03:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 10:07:48 -0700
commit37ddb13ece32500bf87af5d8b8493be1c77781de (patch)
treed3b5daf4c29b5402a0db9e3f316665b5c1d1a9b1
parent26509bf4e202c09da4f0b00d43ebddf87368a0f2 (diff)
Roll forward change "Move control flow functionalization as a graph optimization pass, instead of a step in XlaCompiler.".
PiperOrigin-RevId: 212657932
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc12
-rw-r--r--tensorflow/compiler/tf2xla/BUILD18
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc10
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc147
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h13
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc25
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc23
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc1
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc8
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc102
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h62
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc17
-rw-r--r--tensorflow/core/framework/function.cc11
-rw-r--r--tensorflow/core/framework/function.h4
16 files changed, 435 insertions, 32 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index a989f15a1c..7d5db713f6 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -265,6 +265,7 @@ cc_library(
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
+ "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index c37b6112cc..5dcf754969 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -21,6 +21,18 @@ limitations under the License.
namespace tensorflow {
+// PRE_PLACEMENT passes:
+
+// from
+// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
+// FunctionalizeControlFlowPass: 27
+//
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (XlaIf/XlaWhile). Following passes must
+// handle those FunctionDefs correctly.
+
+// POST_REWRITE_FOR_EXEC passes:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ab289a2b6c..e29a4c0603 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -76,6 +76,7 @@ cc_library(
deps = [
":common",
":dump_graph",
+ ":functionalize_control_flow",
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
@@ -188,7 +189,6 @@ cc_library(
deps = [
":common",
":dump_graph",
- ":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
":side_effect_util",
@@ -284,6 +284,7 @@ cc_library(
deps = [
":sharding_util",
":tf2xla_proto",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -479,6 +480,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -506,12 +508,24 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
+ name = "functionalize_control_flow_pass_registration",
+ srcs = [
+ "functionalize_control_flow_pass_registration.cc",
+ ],
+ deps = [
+ ":functionalize_control_flow",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "functionalize_while",
srcs = [
"functionalize_while.cc",
@@ -520,6 +534,7 @@ cc_library(
"functionalize_while.h",
],
deps = [
+ ":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
@@ -530,6 +545,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index 3ad1d1d5b4..ca64f3f226 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
using xla::StatusOr;
@@ -638,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If");
+ NodeDefBuilder builder(name(), "If", library);
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast<int>(branch);
@@ -1284,6 +1285,13 @@ Status FunctionalizeCond::FunctionalizeInternal() {
std::vector<int> switch_ids;
std::vector<Node*> merge_order;
DFS(*graph_, nullptr, [&](Node* n) {
+ // Nodes marked with _xla_outside_compilation are skipped, because they need
+ // to be executed on host with regular TF executor, which does not support
+ // XlaIf/XlaWhile.
+ if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) {
+ return;
+ }
+
if (IsSwitch(n)) {
switch_ids.push_back(n->id());
}
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 5932be4e52..f792c52032 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,16 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -68,4 +73,146 @@ Status FunctionalizeControlFlow(Graph* graph,
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
+Status FunctionalizeControlFlowForFunction(
+ const string& func_name, const string& new_func_name,
+ const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+ FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+ std::map<string, string>* canonicalized_name_to_new_name) {
+ // Convert the function to Graph.
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
+ Status ret_status = Status::OK();
+ auto cleanup_handle = gtl::MakeCleanup([&]() {
+ auto s = flr->ReleaseHandle(handle);
+ if (!s.ok()) {
+ ret_status.Update(s);
+ }
+ });
+ const FunctionBody* body = flr->GetFunctionBody(handle);
+ const FunctionDef& fdef = body->fdef;
+
+ // If any node has associated functions, functionalize them first.
+ // Gather nodes with associated functions first, because rewriting those nodes
+  // might involve node deletion/addition. Avoid modifying nodes while
+  // iterating over them.
+ std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
+ nodes_to_associated_functions;
+ for (auto* n : body->graph->nodes()) {
+ auto associated_functions = GetAssociatedFunctions(*n, flr);
+ if (!associated_functions.empty()) {
+ nodes_to_associated_functions.push_back({n, associated_functions});
+ }
+ }
+ for (auto iter : nodes_to_associated_functions) {
+ Node* n = iter.first;
+ auto associated_functions = iter.second;
+ for (auto& associated_function : associated_functions) {
+ string name = associated_function.func_name();
+ string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
+ string new_name;
+ if (iter != canonicalized_name_to_new_name->end()) {
+ // If we already functionalized this function, skip functionalization
+ // but still rewrite the node.
+ new_name = iter->second;
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ }
+ // Notice that if "n" is a function call, RewriteAssociatedFunction() will
+ // delete it and create a new node instead, making "n" an invalid pointer.
+ // That's fine because in that case, associated_functions will only have
+ // one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ body->graph, n, fld, associated_function, new_name));
+ }
+ }
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *body->graph, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+ *body->graph, fld);
+ }
+ FunctionDef functionalized_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
+
+ // Copy signature and ret from original FunctionDef.
+ *functionalized_fdef.mutable_signature() = fdef.signature();
+ *functionalized_fdef.mutable_ret() = fdef.ret();
+ functionalized_fdef.mutable_signature()->set_name(new_func_name);
+
+ // Add rewritten FunctionDef into library.
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
+
+ return ret_status;
+}
+
+Status FunctionalizeControlFlowPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph,
+ options.flib_def);
+ }
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+ new ProcessFunctionLibraryRuntime(
+ /*device_mgr=*/nullptr, options.session_options->env,
+ TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
+ FunctionLibraryRuntime* flr =
+ pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+  // Find XLA compile ops and their corresponding FunctionDefs.
+ static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
+ new std::map<string, string>{
+ {"TPUCompile", "function"},
+ {"XlaLaunch", "function"},
+ };
+ std::map<string, string> canonicalized_name_to_new_name;
+ for (Node* n : graph->nodes()) {
+ auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
+ if (it == kNodeTypeToFunctionAttrMapping->end()) {
+ continue;
+ }
+ const string func_attr = it->second;
+ if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
+ kNodeTypeToFunctionAttrMapping->end()) {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
+ VLOG(2) << "Graph has node " << n->type_string()
+ << ". Corresponding function: " << func.name();
+ string new_func_name = options.flib_def->UniqueFunctionName(
+ absl::StrCat(func.name(), "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ func.name(), new_func_name, func.attr(), options.flib_def, flr,
+ &canonicalized_name_to_new_name));
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
+ }
+
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
+ options.flib_def);
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index 55600f2a8b..f1cbcdf617 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -32,6 +33,18 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (XlaIf/XlaWhile).
+//
+// Notice that control flow structures marked with _xla_outside_compilation are
+// skipped, because they need to be executed on host with regular TF executor,
+// which does not support XlaIf/XlaWhile.
+class FunctionalizeControlFlowPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
new file mode 100644
index 0000000000..a10a9d0499
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
+
+namespace tensorflow {
+
+// This pass is required for some AOT backends and all JIT backends, so this
+// file exists as a separate lib and will be linked to both AOT and JIT.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27,
+ FunctionalizeControlFlowPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 7f45e3bffa..2173e15e03 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
@@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
}
- // Builds the condition and body functions.
+ // Builds the condition and body functions. Notice that we call
+ // FunctionalizeCond() on cond_graph and body_graph because we might have
+ // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
+ // before they are encapsulated in FunctionDef.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ FixupSourceAndSinkEdges(cond_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+ FixupSourceAndSinkEdges(body_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -641,8 +650,14 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
- TF_RETURN_IF_ERROR(
- FunctionalizeLoop(lookup_library, graph, frame, library));
+ // Nodes marked with _xla_outside_compilation are skipped, because they need
+ // to be executed on host with regular TF executor, which does not support
+ // XlaIf/XlaWhile.
+ string name;
+ if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) {
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
+ }
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 82e9eef005..c019a28e89 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 7dbe3a0b58..b22d53805d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+
+ // Functionalize control flow.
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
+ // After control flow functionalization, we might have more FunctionDef's
+ // (then/else branch, loop body). Add them to the graph.
+ TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
+
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 211caf8736..d6f42bac86 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -25,9 +25,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
+const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
+
Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
for (const tf2xla::Feed& feed : config.feed()) {
@@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() {
return counter.fetch_add(2);
}
+// TODO(b/77601805): add tests for associated function related stuff.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr) {
+ if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) {
+ return true;
+ }
+
+ if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. Gradient op has "f" attr, which is set to the function
+ // we are getting gradient for. That function is not associated with the op.
+ return false;
+ }
+
+ for (const auto& iter : node_def.attr()) {
+ if (iter.second.has_func()) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr) {
+ std::vector<AssociatedFunctionInfo> results;
+ const string& op = node.type_string();
+ if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
+ // This is a function call node.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. Gradient op has "f" attr, which is set to the function
+ // we are getting gradient for. That function is not associated with the op.
+ } else {
+ // Collect all function attrs for the node.
+ for (auto& iter : node.attrs()) {
+ if (iter.second.has_func()) {
+ VLOG(2) << "Found function attr for node " << node.name() << ": "
+ << iter.first << " = " << iter.second.func().name();
+ results.emplace_back(AssociatedFunctionInfo(
+ iter.second.func().name(), iter.second.func().attr(), iter.first));
+ }
+ }
+ }
+ return results;
+}
+
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name) {
+ switch (associated_function.type()) {
+ case AssociatedFunctionInfo::kFunctionCallNode: {
+ // Change this node to call the new function.
+ NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
+ for (auto attr : node->attrs()) {
+ builder.Attr(attr.first, attr.second);
+ }
+ for (int i = 0; i < node->num_inputs(); i++) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
+ builder.Input(input_node->name(), i, node->input_type(i));
+ }
+ builder.Device(node->assigned_device_name().empty()
+ ? node->requested_device()
+ : node->assigned_device_name());
+ NodeDef node_def;
+ TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
+ Status s;
+ Node* new_node = graph->AddNode(node_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ for (auto edge : node->in_edges()) {
+ graph->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input());
+ }
+ for (auto edge : node->out_edges()) {
+ graph->AddEdge(new_node, edge->src_output(), edge->dst(),
+ edge->dst_input());
+ }
+ graph->RemoveNode(node);
+ break;
+ }
+ case AssociatedFunctionInfo::kFunctionAttr: {
+ // Change function attr to rewritten functions.
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
+ node->ClearAttr(associated_function.attr_name());
+ func.set_name(rewritten_function_name);
+ node->AddAttr(associated_function.attr_name(), func);
+ break;
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index a29e764466..6065d0bb9a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -59,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
+// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
+// a function call, or the node has function attrs).
+class AssociatedFunctionInfo {
+ public:
+ enum AssociatedFunctionType {
+ kFunctionCallNode = 0,
+ kFunctionAttr = 1,
+ };
+
+ // The node is a function call.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
+ : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
+
+ // The function is an attr of the node.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
+ const string& attr_name)
+ : type_(kFunctionAttr),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
+ AssociatedFunctionType type() const { return type_; }
+
+ const string& func_name() const { return func_name_; }
+
+ const string& attr_name() const { return attr_name_; }
+
+ const AttrValueMap& attrs() const { return attrs_; }
+
+ private:
+ // Available for all instances.
+ AssociatedFunctionType type_;
+ string func_name_;
+ AttrValueMap attrs_;
+
+ // Only available if the function is defined in an attr.
+ string attr_name_;
+};
+
+// Returns whether the NodeDef has an associated function.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr);
+
+// Gets functions associated with the node. Current cases:
+// 1. For function call node, its function name;
+// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr);
+
+// Changes associated functions for the node. Current cases:
+// 1. For function call node, creates a new node with the new function name and
+// remove the old node;
+// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name);
+
+// Attribute to mark nodes to be executed on host.
+extern const char kXlaOutsideCompilationAttrName[];
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index dcb455779d..105f3b61d5 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
TF_RETURN_WITH_CONTEXT_IF_ERROR(
GetFunctionBody(function, flib_runtime_, fbody),
"Local lookup failed with: ", status.error_message());
+ VLOG(4) << "Function " << function.name() << " in flib_runtime_";
+ } else {
+ VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
}
return Status::OK();
}
@@ -743,18 +745,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
- absl::StrCat("xla_compile_graph_", name), *graph);
+ absl::StrCat("xla_compile_graph_", name), *graph,
+ flib_runtime_->GetFunctionLibraryDefinition());
}
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
- // Converts Tensorflow's graph control-flow constructs into functional
- // control-flow that can be compiled into XLA code.
- TF_RETURN_IF_ERROR(
- FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
- graph.get(), local_flib_def_.get()));
-
// Detect invalid nodes.
// FunctionalizeControlFlow may remove some nodes from the graph.
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 70efa7781d..100b10cd83 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -1219,25 +1219,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
- status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
- std::move(graph_copy), args, &result);
- ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- absl::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: {{node NoOp}}"))
- << status.error_message();
- }
-
- // Fix control edges for NoOp.
- {
- std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
- CopyGraph(*graph, graph_copy.get());
- EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
- XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result));
- EXPECT_EQ(0, result.resource_updates.size());
}
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 26f32677af..d979353d2f 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+ tf_shared_lock l(mu_);
+ int index = 0;
+ string name = strings::StrCat(prefix, index);
+ while (function_defs_.find(name) != function_defs_.end()) {
+ ++index;
+ name = strings::StrCat(prefix, index);
+ }
+ return name;
+}
+
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 03296a7761..e01eb7503d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
+ // Generates new function name with the specified prefix that is unique
+ // across this library.
+ string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";