aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-09 09:50:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-09 09:54:26 -0700
commitb40ace8f28315431e3435647ce39cc7b24c20bfd (patch)
tree94c8567f43faec1411ae66c157b2ae13ce658838
parentd31f360e1574553ed23b8d483512a2065ac426eb (diff)
Automated rollback of commit a3776a234f555213aafcf41f49a42a8a9448c4ac
PiperOrigin-RevId: 212182923
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc12
-rw-r--r--tensorflow/compiler/tf2xla/BUILD18
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc10
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc133
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h13
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc25
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc25
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc1
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc8
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc102
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h62
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc17
-rw-r--r--tensorflow/core/framework/function.cc11
-rw-r--r--tensorflow/core/framework/function.h4
16 files changed, 32 insertions, 423 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7d5db713f6..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -265,7 +265,6 @@ cc_library(
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
- "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 5dcf754969..c37b6112cc 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -21,18 +21,6 @@ limitations under the License.
namespace tensorflow {
-// PRE_PLACEMENT passes:
-
-// from
-// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
-// FunctionalizeControlFlowPass: 27
-//
-// This pass looks at the graph and all associated FunctionDefs, and turns
-// traditional control flow structure (Switch/Merge/etc.) into functional
-// control flow structure (XlaIf/XlaWhile). Following passes must
-// handle those FunctionDef correctly.
-
-// POST_REWRITE_FOR_EXEC passes:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index b28ffaf8a4..3821dced63 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -76,7 +76,6 @@ cc_library(
deps = [
":common",
":dump_graph",
- ":functionalize_control_flow",
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
@@ -189,6 +188,7 @@ cc_library(
deps = [
":common",
":dump_graph",
+ ":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
":side_effect_util",
@@ -285,7 +285,6 @@ cc_library(
deps = [
":sharding_util",
":tf2xla_proto",
- "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -481,7 +480,6 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
- "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -509,24 +507,12 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
- "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
- name = "functionalize_control_flow_pass_registration",
- srcs = [
- "functionalize_control_flow_pass_registration.cc",
- ],
- deps = [
- ":functionalize_control_flow",
- ],
- alwayslink = 1,
-)
-
-cc_library(
name = "functionalize_while",
srcs = [
"functionalize_while.cc",
@@ -535,7 +521,6 @@ cc_library(
"functionalize_while.h",
],
deps = [
- ":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
@@ -546,7 +531,6 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
- "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index 55439e77a6..0911550f1f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/strings/strcat.h"
using xla::StatusOr;
@@ -643,7 +642,7 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If", library);
+ NodeDefBuilder builder(name(), "If");
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast<int>(branch);
@@ -1253,13 +1252,6 @@ Status FunctionalizeCond::FunctionalizeInternal() {
std::vector<int> switch_ids;
std::vector<Node*> merge_order;
DFS(*graph_, nullptr, [&](Node* n) {
- // Nodes marked with _xla_outside_compilation are skipped, because they need
- // to be executed on host with regular TF executor, which does not support
- // XlaIf/XlaWhile.
- if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) {
- return;
- }
-
if (IsSwitch(n)) {
switch_ids.push_back(n->id());
}
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 622767f68d..5932be4e52 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,16 +31,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -73,132 +68,4 @@ Status FunctionalizeControlFlow(Graph* graph,
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
-Status FunctionalizeControlFlowForFunction(
- const string& func_name, const string& new_func_name,
- const protobuf::Map<string, tensorflow::AttrValue>& attrs,
- FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
- std::map<string, string>* canonicalized_name_to_new_name) {
- // Convert the function to Graph.
- FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
- Status ret_status = Status::OK();
- auto cleanup_handle = gtl::MakeCleanup([&]() {
- auto s = flr->ReleaseHandle(handle);
- if (!s.ok()) {
- ret_status.Update(s);
- }
- });
- const FunctionBody* body = flr->GetFunctionBody(handle);
- const FunctionDef& fdef = body->fdef;
-
- // If any node has associated functions, functionalize them first.
- for (auto* n : body->graph->nodes()) {
- auto associated_functions = GetAssociatedFunctions(*n, flr);
- for (auto& associated_function : associated_functions) {
- string name = associated_function.func_name();
- string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
- // If we already functionalized this function, skip it.
- auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
- if (iter != canonicalized_name_to_new_name->end()) {
- continue;
- }
-
- string new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
- TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
- (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
- // Notice that if "n" is a function call, RewriteAssociatedFunction() will
- // delete it and create a new node instead, making "n" an invalid pointer.
- // That's fine because in that case, associated_functions will only have
- // one member and the loop will only run once.
- TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
- }
- }
-
- // Functionalize the function body.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *body->graph, fld);
- }
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *body->graph, fld);
- }
- FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
-
- // Copy signature and ret from original FunctionDef.
- *functionalized_fdef.mutable_signature() = fdef.signature();
- *functionalized_fdef.mutable_ret() = fdef.ret();
- functionalized_fdef.mutable_signature()->set_name(new_func_name);
-
- // Add rewritten FunctionDef into library.
- if (func_name == new_func_name) {
- VLOG(2) << "Replacing function " << func_name;
- TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(new_func_name, functionalized_fdef));
- } else {
- VLOG(2) << "Adding function " << new_func_name;
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
- }
-
- return ret_status;
-}
-
-Status FunctionalizeControlFlowPass::Run(
- const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph,
- options.flib_def);
- }
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
- new ProcessFunctionLibraryRuntime(
- /*device_mgr=*/nullptr, options.session_options->env,
- TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
- FunctionLibraryRuntime* flr =
- pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
-
- // Find XLA compile ops and its corresponding FunctionDef.
- static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
- new std::map<string, string>{
- {"TPUCompile", "function"},
- {"XlaLaunch", "function"},
- };
- std::map<string, string> canonicalized_name_to_new_name;
- for (Node* n : graph->nodes()) {
- auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
- if (it == kNodeTypeToFunctionAttrMapping->end()) {
- continue;
- }
- const string func_attr = it->second;
- if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
- kNodeTypeToFunctionAttrMapping->end()) {
- NameAttrList func;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
- VLOG(2) << "Graph has node " << n->type_string()
- << ". Corresponding function: " << func.name();
- string new_func_name = options.flib_def->UniqueFunctionName(
- absl::StrCat(func.name(), "_f15n_"));
- TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- func.name(), new_func_name, func.attr(), options.flib_def, flr,
- &canonicalized_name_to_new_name));
- n->ClearAttr(func_attr);
- func.set_name(new_func_name);
- n->AddAttr(func_attr, func);
- }
- }
-
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index f1cbcdf617..55600f2a8b 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -33,18 +32,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
-// This pass looks at the graph and all associated FunctionDefs, and turns
-// traditional control flow structure (Switch/Merge/etc.) into functional
-// control flow structure (XlaIf/XlaWhile).
-//
-// Notice that control flow structure marked with _xla_outside_compilation are
-// skipped, because they need to be executed on host with regular TF executor,
-// which does not support XlaIf/XlaWhile.
-class FunctionalizeControlFlowPass : public GraphOptimizationPass {
- public:
- Status Run(const GraphOptimizationPassOptions& options) override;
-};
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
deleted file mode 100644
index a10a9d0499..0000000000
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
-
-namespace tensorflow {
-
-// This pass is required for some AOT backends and all JIT backends, so this
-// file exists as a separate lib and will be linked to both AOT and JIT.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27,
- FunctionalizeControlFlowPass);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index f905c6a0fc..7f45e3bffa 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -35,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
@@ -475,21 +473,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
}
- // Builds the condition and body functions. Notice that we call
- // FunctionalizeCond() on cond_graph and body_graph because we might have
- // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
- // before they are encapsulated in FunctionDef.
- // TODO(b/114485797): current logic does not functionalize while loop in
- // another loop cond.
+ // Builds the condition and body functions.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
- FixupSourceAndSinkEdges(cond_graph.get());
- TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
- FixupSourceAndSinkEdges(body_graph.get());
- TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -521,7 +510,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library);
+ NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -652,14 +641,8 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
- // Nodes marked with _xla_outside_compilation are skipped, because they need
- // to be executed on host with regular TF executor, which does not support
- // XlaIf/XlaWhile.
- string name;
- if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) {
- TF_RETURN_IF_ERROR(
- FunctionalizeLoop(lookup_library, graph, frame, library));
- }
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index fa25a230b0..bc2e640559 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index b22d53805d..7dbe3a0b58 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -341,13 +340,6 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
-
- // Functionalize control flow.
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
- // After control flow functionalization, we might have more FunctionDef's
- // (then/else branch, loop body). Add them to the graph.
- TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
-
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..211caf8736 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -25,12 +25,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -78,8 +75,6 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
-const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
-
Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
for (const tf2xla::Feed& feed : config.feed()) {
@@ -328,101 +323,4 @@ uint32 GetXLARandomSeed() {
return counter.fetch_add(2);
}
-// TODO(b/77601805): add tests for associated function related stuff.
-bool HasAssociatedFunction(const NodeDef& node_def,
- FunctionLibraryRuntime* flr) {
- if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) {
- return true;
- }
-
- if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
- }
-
- for (const auto& iter : node_def.attr()) {
- if (iter.second.has_func()) {
- return true;
- }
- }
-
- return false;
-}
-
-std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
- const Node& node, FunctionLibraryRuntime* flr) {
- std::vector<AssociatedFunctionInfo> results;
- const string& op = node.type_string();
- if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
- // This is a function call node.
- AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
- } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- } else {
- // Collect all function attrs for the node.
- for (auto& iter : node.attrs()) {
- if (iter.second.has_func()) {
- VLOG(2) << "Found function attr for node " << node.name() << ": "
- << iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
- iter.second.func().name(), iter.second.func().attr(), iter.first));
- }
- }
- }
- return results;
-}
-
-Status RewriteAssociatedFunction(
- Graph* graph, Node* node, FunctionLibraryDefinition* fld,
- const AssociatedFunctionInfo& associated_function,
- const string& rewritten_function_name) {
- switch (associated_function.type()) {
- case AssociatedFunctionInfo::kFunctionCallNode: {
- // Change this node to call the new function.
- NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
- for (auto attr : node->attrs()) {
- builder.Attr(attr.first, attr.second);
- }
- for (int i = 0; i < node->num_inputs(); i++) {
- Node* input_node;
- TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
- builder.Input(input_node->name(), i, node->input_type(i));
- }
- builder.Device(node->assigned_device_name().empty()
- ? node->requested_device()
- : node->assigned_device_name());
- NodeDef node_def;
- TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
- Status s;
- Node* new_node = graph->AddNode(node_def, &s);
- TF_RETURN_IF_ERROR(s);
- for (auto edge : node->in_edges()) {
- graph->AddEdge(edge->src(), edge->src_output(), new_node,
- edge->dst_input());
- }
- for (auto edge : node->out_edges()) {
- graph->AddEdge(new_node, edge->src_output(), edge->dst(),
- edge->dst_input());
- }
- graph->RemoveNode(node);
- break;
- }
- case AssociatedFunctionInfo::kFunctionAttr: {
- // Change function attr to rewritten functions.
- NameAttrList func;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
- node->ClearAttr(associated_function.attr_name());
- func.set_name(rewritten_function_name);
- node->AddAttr(associated_function.attr_name(), func);
- break;
- }
- }
-
- return Status::OK();
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 41e70e0658..dcddef8418 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
-#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -61,67 +60,6 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
-// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
-// a function call, or the node has function attrs).
-class AssociatedFunctionInfo {
- public:
- enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
- };
-
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
- // The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
-
- AssociatedFunctionType type() const { return type_; }
-
- const string& func_name() const { return func_name_; }
-
- const string& attr_name() const { return attr_name_; }
-
- const AttrValueMap& attrs() const { return attrs_; }
-
- private:
- // Available for all instances.
- AssociatedFunctionType type_;
- string func_name_;
- AttrValueMap attrs_;
-
- // Only available if the function is defined in an attr.
- string attr_name_;
-};
-
-// Returns if the NodeDef has associated function.
-bool HasAssociatedFunction(const NodeDef& node_def,
- FunctionLibraryRuntime* flr);
-
-// Gets functions associated with the node. Current cases:
-// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
-std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
- const Node& node, FunctionLibraryRuntime* flr);
-
-// Changes associated functions for the node. Current cases:
-// 1. For function call node, creates a new node with the new function name and
-// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
-Status RewriteAssociatedFunction(
- Graph* graph, Node* node, FunctionLibraryDefinition* fld,
- const AssociatedFunctionInfo& associated_function,
- const string& rewritten_function_name);
-
-// Attribute to mark nodes to be executed on host.
-extern const char kXlaOutsideCompilationAttrName[];
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 105f3b61d5..dcb455779d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@@ -149,9 +150,6 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
TF_RETURN_WITH_CONTEXT_IF_ERROR(
GetFunctionBody(function, flib_runtime_, fbody),
"Local lookup failed with: ", status.error_message());
- VLOG(4) << "Function " << function.name() << " in flib_runtime_";
- } else {
- VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
}
return Status::OK();
}
@@ -745,13 +743,18 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
- absl::StrCat("xla_compile_graph_", name), *graph,
- flib_runtime_->GetFunctionLibraryDefinition());
+ absl::StrCat("xla_compile_graph_", name), *graph);
}
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
+ // Converts Tensorflow's graph control-flow constructs into functional
+ // control-flow that can be compiled into XLA code.
+ TF_RETURN_IF_ERROR(
+ FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
+ graph.get(), local_flib_def_.get()));
+
// Detect invalid nodes.
// FunctionalizeControlFlow may remove some nodes from the graph.
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 42de6bacd6..40ce9fb41c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -1255,8 +1255,25 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
+ std::move(graph_copy), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ absl::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: {{node NoOp}}"))
+ << status.error_message();
+ }
+
+ // Fix control edges for NoOp.
+ {
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
+ XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result));
+ EXPECT_EQ(0, result.resource_updates.size());
}
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index d979353d2f..26f32677af 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,17 +1154,6 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
-string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
- tf_shared_lock l(mu_);
- int index = 0;
- string name = strings::StrCat(prefix, index);
- while (function_defs_.find(name) != function_defs_.end()) {
- ++index;
- name = strings::StrCat(prefix, index);
- }
- return name;
-}
-
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..03296a7761 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,10 +358,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
- // Generates new function name with the specified prefix that is unique
- // across this library.
- string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
-
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";