diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-12 10:03:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 10:07:48 -0700 |
commit | 37ddb13ece32500bf87af5d8b8493be1c77781de (patch) | |
tree | d3b5daf4c29b5402a0db9e3f316665b5c1d1a9b1 | |
parent | 26509bf4e202c09da4f0b00d43ebddf87368a0f2 (diff) |
Roll forward change "Move control flow functionalization as a graph optimization pass, instead of a step in XlaCompiler.".
PiperOrigin-RevId: 212657932
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_cond.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 147 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.h | 13 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_while.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/graph_compiler.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla_util.cc | 102 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla_util.h | 62 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 4 |
16 files changed, 435 insertions, 32 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a989f15a1c..7d5db713f6 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,6 +265,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc..5dcf754969 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,6 +21,18 @@ limitations under the License. namespace tensorflow { +// PRE_PLACEMENT passes: + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes: REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ab289a2b6c..e29a4c0603 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -188,7 +189,6 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -284,6 +284,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -479,6 +480,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -506,12 +508,24 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + ":functionalize_control_flow", + ], + alwayslink = 1, +) + +cc_library( name = "functionalize_while", srcs = [ "functionalize_while.cc", @@ -520,6 +534,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -530,6 +545,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 3ad1d1d5b4..ca64f3f226 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -638,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If"); + NodeDefBuilder builder(name(), "If", library); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast<int>(branch); @@ -1284,6 +1285,13 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector<int> switch_ids; std::vector<Node*> merge_order; DFS(*graph_, nullptr, [&](Node* n) { + // Nodes marked with _xla_outside_compilation are skipped, because they need + // to be executed on host with regular TF executor, which does not support + // XlaIf/XlaWhile. + if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) { + return; + } + if (IsSwitch(n)) { switch_ids.push_back(n->id()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5932be4e52..f792c52032 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -68,4 +73,146 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map<string, tensorflow::AttrValue>& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map<string, string>* canonicalized_name_to_new_name) { + // Convert the function to Graph. + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + const FunctionDef& fdef = body->fdef; + + // If any node has associated functions, functionalize them first. + // Gather nodes with associated functions first, because rewriting those nodes + // might involve node deletion/addition. Avoid modifying nodes while iterating + // it. + std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>> + nodes_to_associated_functions; + for (auto* n : body->graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, flr); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (auto iter : nodes_to_associated_functions) { + Node* n = iter.first; + auto associated_functions = iter.second; + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + string new_name; + if (iter != canonicalized_name_to_new_name->end()) { + // If we already functionalized this function, skip functionalization + // but still rewrite the node. + new_name = iter->second; + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. + // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + body->graph, n, fld, associated_function, new_name)); + } + } + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *body->graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *body->graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); + + // Copy signature and ret from original FunctionDef. + *functionalized_fdef.mutable_signature() = fdef.signature(); + *functionalized_fdef.mutable_ret() = fdef.ret(); + functionalized_fdef.mutable_signature()->set_name(new_func_name); + + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. + static std::map<string, string>* kNodeTypeToFunctionAttrMapping = + new std::map<string, string>{ + {"TPUCompile", "function"}, + {"XlaLaunch", "function"}, + }; + std::map<string, string> canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; + } + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 55600f2a8b..f1cbcdf617 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -32,6 +33,18 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). +// +// Notice that control flow structure marked with _xla_outside_compilation are +// skipped, because they need to be executed on host with regular TF executor, +// which does not support XlaIf/XlaWhile. +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000..a10a9d0499 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 7f45e3bffa..2173e15e03 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. + // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. std::unique_ptr<Graph> cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr<Graph> body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -641,8 +650,14 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); + // Nodes marked with _xla_outside_compilation are skipped, because they need + // to be executed on host with regular TF executor, which does not support + // XlaIf/XlaWhile. + string name; + if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) { + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); + } // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 82e9eef005..c019a28e89 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 7dbe3a0b58..b22d53805d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. + TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 211caf8736..d6f42bac86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -25,9 +25,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set<string> names; for (const tf2xla::Feed& feed : config.feed()) { @@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr) { + if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + return false; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr) { + std::vector<AssociatedFunctionInfo> results; + const string& op = node.type_string(); + if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + // This is a function call node. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + } else { + // Collect all function attrs for the node. + for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. + NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index a29e764466..6065d0bb9a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -19,6 +19,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -59,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionCallNode = 0, + kFunctionAttr = 1, + }; + + // The node is a function call. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) + : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} + + // The function is an attr of the node. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, + const string& attr_name) + : type_(kFunctionAttr), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. +extern const char kXlaOutsideCompilationAttrName[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index dcb455779d..105f3b61d5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -743,18 +745,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 70efa7781d..100b10cd83 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1219,25 +1219,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); - } - - // Fix control edges for NoOp. - { - std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); - CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); - XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); } } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 26f32677af..d979353d2f 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } +string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { + tf_shared_lock l(mu_); + int index = 0; + string name = strings::StrCat(prefix, index); + while (function_defs_.find(name) != function_defs_.end()) { + ++index; + name = strings::StrCat(prefix, index); + } + return name; +} + const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if (ndef.op() != kGradientOp) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 03296a7761..e01eb7503d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); + // Generates new function name with the specified prefix that is unique + // across this library. + string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); + // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. static constexpr const char* const kArgOp = "_Arg"; |