diff options
author | 2018-09-09 09:50:03 -0700 | |
---|---|---|
committer | 2018-09-09 09:54:26 -0700 | |
commit | b40ace8f28315431e3435647ce39cc7b24c20bfd (patch) | |
tree | 94c8567f43faec1411ae66c157b2ae13ce658838 | |
parent | d31f360e1574553ed23b8d483512a2065ac426eb (diff) |
Automated rollback of commit a3776a234f555213aafcf41f49a42a8a9448c4ac
PiperOrigin-RevId: 212182923
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_cond.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 133 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.h | 13 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_while.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/graph_compiler.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla_util.cc | 102 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/tf2xla_util.h | 62 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 4 |
16 files changed, 32 insertions, 423 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d5db713f6..a989f15a1c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,7 +265,6 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", - "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 5dcf754969..c37b6112cc 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,18 +21,6 @@ limitations under the License. namespace tensorflow { -// PRE_PLACEMENT passes: - -// from -// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc -// FunctionalizeControlFlowPass: 27 -// -// This pass looks at the graph and all associated FunctionDefs, and turns -// traditional control flow structure (Switch/Merge/etc.) into functional -// control flow structure (XlaIf/XlaWhile). Following passes must -// handle those FunctionDef correctly. - -// POST_REWRITE_FOR_EXEC passes: REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index b28ffaf8a4..3821dced63 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,7 +76,6 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -189,6 +188,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -285,7 +285,6 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -481,7 +480,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -509,24 +507,12 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "functionalize_control_flow_pass_registration", - srcs = [ - "functionalize_control_flow_pass_registration.cc", - ], - deps = [ - ":functionalize_control_flow", - ], - alwayslink = 1, -) - -cc_library( name = "functionalize_while", srcs = [ "functionalize_while.cc", @@ -535,7 +521,6 @@ cc_library( "functionalize_while.h", ], deps = [ - ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -546,7 +531,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 55439e77a6..0911550f1f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -643,7 +642,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDefBuilder builder(name(), "If"); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast<int>(branch); @@ -1253,13 +1252,6 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector<int> switch_ids; std::vector<Node*> merge_order; DFS(*graph_, nullptr, [&](Node* n) { - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) { - return; - } - if (IsSwitch(n)) { switch_ids.push_back(n->id()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 622767f68d..5932be4e52 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,16 +31,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -73,132 +68,4 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } -Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map<string, tensorflow::AttrValue>& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map<string, string>* canonicalized_name_to_new_name) { - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - const FunctionDef& fdef = body->fdef; - - // If any node has associated functions, functionalize them first. - for (auto* n : body->graph->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, flr); - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); - // If we already functionalized this function, skip it. - auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - if (iter != canonicalized_name_to_new_name->end()) { - continue; - } - - string new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - body->graph, n, fld, associated_function, new_name)); - } - } - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *body->graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *body->graph, fld); - } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); - - // Copy signature and ret from original FunctionDef. - *functionalized_fdef.mutable_signature() = fdef.signature(); - *functionalized_fdef.mutable_ret() = fdef.ret(); - functionalized_fdef.mutable_signature()->set_name(new_func_name); - - // Add rewritten FunctionDef into library. - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } - - return ret_status; -} - -Status FunctionalizeControlFlowPass::Run( - const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, - options.flib_def); - } - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( - new ProcessFunctionLibraryRuntime( - /*device_mgr=*/nullptr, options.session_options->env, - TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); - FunctionLibraryRuntime* flr = - pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - - // Find XLA compile ops and its corresponding FunctionDef. - static std::map<string, string>* kNodeTypeToFunctionAttrMapping = - new std::map<string, string>{ - {"TPUCompile", "function"}, - {"XlaLaunch", "function"}, - }; - std::map<string, string> canonicalized_name_to_new_name; - for (Node* n : graph->nodes()) { - auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); - if (it == kNodeTypeToFunctionAttrMapping->end()) { - continue; - } - const string func_attr = it->second; - if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, - options.flib_def); - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f1cbcdf617..55600f2a8b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -33,18 +32,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); -// This pass looks at the graph and all associated FunctionDefs, and turns -// traditional control flow structure (Switch/Merge/etc.) into functional -// control flow structure (XlaIf/XlaWhile). -// -// Notice that control flow structure marked with _xla_outside_compilation are -// skipped, because they need to be executed on host with regular TF executor, -// which does not support XlaIf/XlaWhile. -class FunctionalizeControlFlowPass : public GraphOptimizationPass { - public: - Status Run(const GraphOptimizationPassOptions& options) override; -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc deleted file mode 100644 index a10a9d0499..0000000000 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" - -namespace tensorflow { - -// This pass is required for some AOT backends and all JIT backends, so this -// file exists as a separate lib and will be linked to both AOT and JIT. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, - FunctionalizeControlFlowPass); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index f905c6a0fc..7f45e3bffa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,7 +25,6 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -35,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -475,21 +473,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. Notice that we call - // FunctionalizeCond() on cond_graph and body_graph because we might have - // unfunctionalized "if" in cond_graph and body_graph. Functionalize them - // before they are encapsulated in FunctionDef. - // TODO(b/114485797): current logic does not functionalize while loop in - // another loop cond. + // Builds the condition and body functions. std::unique_ptr<Graph> cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); - FixupSourceAndSinkEdges(cond_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr<Graph> body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); - FixupSourceAndSinkEdges(body_graph.get()); - TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -521,7 +510,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library); + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -652,14 +641,8 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - string name; - if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) { - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - } + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index fa25a230b0..bc2e640559 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b22d53805d..7dbe3a0b58 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -341,13 +340,6 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); - - // Functionalize control flow. - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); - // After control flow functionalization, we might have more FunctionDef's - // (then/else branch, loop body). Add them to the graph. - TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); - *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index d6f42bac86..211caf8736 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -25,12 +25,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -78,8 +75,6 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace -const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; - Status ValidateConfig(const tf2xla::Config& config) { std::set<string> names; for (const tf2xla::Feed& feed : config.feed()) { @@ -328,101 +323,4 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -// TODO(b/77601805): add tests for associated function related stuff. -bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr) { - if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { - return true; - } - - if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - return false; - } - - for (const auto& iter : node_def.attr()) { - if (iter.second.has_func()) { - return true; - } - } - - return false; -} - -std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr) { - std::vector<AssociatedFunctionInfo> results; - const string& op = node.type_string(); - if (flr->GetFunctionLibraryDefinition()->Contains(op)) { - // This is a function call node. - AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); - results.emplace_back(AssociatedFunctionInfo(op, attrs)); - } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - } else { - // Collect all function attrs for the node. - for (auto& iter : node.attrs()) { - if (iter.second.has_func()) { - VLOG(2) << "Found function attr for node " << node.name() << ": " - << iter.first << " = " << iter.second.func().name(); - results.emplace_back(AssociatedFunctionInfo( - iter.second.func().name(), iter.second.func().attr(), iter.first)); - } - } - } - return results; -} - -Status RewriteAssociatedFunction( - Graph* graph, Node* node, FunctionLibraryDefinition* fld, - const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name) { - switch (associated_function.type()) { - case AssociatedFunctionInfo::kFunctionCallNode: { - // Change this node to call the new function. - NodeDefBuilder builder(node->name(), rewritten_function_name, fld); - for (auto attr : node->attrs()) { - builder.Attr(attr.first, attr.second); - } - for (int i = 0; i < node->num_inputs(); i++) { - Node* input_node; - TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); - builder.Input(input_node->name(), i, node->input_type(i)); - } - builder.Device(node->assigned_device_name().empty() - ? node->requested_device() - : node->assigned_device_name()); - NodeDef node_def; - TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); - Status s; - Node* new_node = graph->AddNode(node_def, &s); - TF_RETURN_IF_ERROR(s); - for (auto edge : node->in_edges()) { - graph->AddEdge(edge->src(), edge->src_output(), new_node, - edge->dst_input()); - } - for (auto edge : node->out_edges()) { - graph->AddEdge(new_node, edge->src_output(), edge->dst(), - edge->dst_input()); - } - graph->RemoveNode(node); - break; - } - case AssociatedFunctionInfo::kFunctionAttr: { - // Change function attr to rewritten functions. - NameAttrList func; - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); - node->ClearAttr(associated_function.attr_name()); - func.set_name(rewritten_function_name); - node->AddAttr(associated_function.attr_name(), func); - break; - } - } - - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 41e70e0658..dcddef8418 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -20,7 +20,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -61,67 +60,6 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); -// Indicates how a FunctionDef is associated with a graph node (e.g. the node is -// a function call, or the node has function attrs). -class AssociatedFunctionInfo { - public: - enum AssociatedFunctionType { - kFunctionCallNode = 0, - kFunctionAttr = 1, - }; - - // The node is a function call. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) - : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} - - // The function is an attr of the node. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, - const string& attr_name) - : type_(kFunctionAttr), - func_name_(func_name), - attrs_(attrs), - attr_name_(attr_name) {} - - AssociatedFunctionType type() const { return type_; } - - const string& func_name() const { return func_name_; } - - const string& attr_name() const { return attr_name_; } - - const AttrValueMap& attrs() const { return attrs_; } - - private: - // Available for all instances. - AssociatedFunctionType type_; - string func_name_; - AttrValueMap attrs_; - - // Only available if the function is defined in an attr. - string attr_name_; -}; - -// Returns if the NodeDef has associated function. -bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr); - -// Gets functions associated with the node. Current cases: -// 1. For function call node, its function name; -// 2. For nodes like XlaWhile/XlaIf, all their function attributes. -std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr); - -// Changes associated functions for the node. Current cases: -// 1. For function call node, creates a new node with the new function name and -// remove the old node; -// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. -Status RewriteAssociatedFunction( - Graph* graph, Node* node, FunctionLibraryDefinition* fld, - const AssociatedFunctionInfo& associated_function, - const string& rewritten_function_name); - -// Attribute to mark nodes to be executed on host. -extern const char kXlaOutsideCompilationAttrName[]; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 105f3b61d5..dcb455779d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" @@ -149,9 +150,6 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); - VLOG(4) << "Function " << function.name() << " in flib_runtime_"; - } else { - VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -745,13 +743,18 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph, - flib_runtime_->GetFunctionLibraryDefinition()); + absl::StrCat("xla_compile_graph_", name), *graph); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); + // Converts Tensorflow's graph control-flow constructs into functional + // control-flow that can be compiled into XLA code. + TF_RETURN_IF_ERROR( + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); + // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 42de6bacd6..40ce9fb41c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1255,8 +1255,25 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); } } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d979353d2f..26f32677af 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,17 +1154,6 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } -string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { - tf_shared_lock l(mu_); - int index = 0; - string name = strings::StrCat(prefix, index); - while (function_defs_.find(name) != function_defs_.end()) { - ++index; - name = strings::StrCat(prefix, index); - } - return name; -} - const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if (ndef.op() != kGradientOp) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e01eb7503d..03296a7761 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,10 +358,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); - // Generates new function name with the specified prefix that is unique - // across this library. - string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); - // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. static constexpr const char* const kArgOp = "_Arg"; |