aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-15 10:25:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 10:29:10 -0700
commitc9a2034f93981e17eef5f96fbd2894202b8fc2c1 (patch)
tree1cd077b6a364a0337fee93901aa4c3c2aee4014a
parent655c52b014df4a9b7dc8212aabb0bdf20da44107 (diff)
[TF:XLA] Validate the control flow structure in encapsulate_subgraphs_pass and encapsulate_tpu_computations_pass, in order to detect errors earlier.
PiperOrigin-RevId: 200735435
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc16
-rw-r--r--tensorflow/compiler/tf2xla/BUILD27
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc15
-rw-r--r--tensorflow/compiler/tf2xla/validate_control_flow.cc84
-rw-r--r--tensorflow/compiler/tf2xla/validate_control_flow.h37
-rw-r--r--tensorflow/compiler/tf2xla/validate_control_flow_test.cc131
7 files changed, 296 insertions, 15 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 8c74014614..a92218b129 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -321,6 +321,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:validate_control_flow",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 9448b8ebde..b78c30c215 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/validate_control_flow.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@@ -1504,6 +1505,11 @@ Status Encapsulator::SplitIntoSubgraphs() {
for (auto& entry : subgraphs_) {
Subgraph& subgraph = entry.second;
FixupSourceAndSinkEdges(subgraph.GetGraph());
+ // Verify that the graph has well-formed control flow structure to be
+ // functionalized.
+ std::vector<ControlFlowInfo> dummy;
+ TF_RETURN_IF_ERROR(
+ BuildAndValidateControlFlowInfo(subgraph.GetGraph(), &dummy));
}
return s;
@@ -2519,10 +2525,12 @@ Status EncapsulateSubgraphsPass::Run(
return Status::OK();
};
- TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
- kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
- rewrite_subgraph,
- /*reuse_existing_functions=*/false, &graph_out, library));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
+ rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out,
+ library),
+ "EncapsulateSubgraphsPass failed");
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index cd57452302..6b73cee2a8 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -407,11 +407,38 @@ cc_library(
)
cc_library(
+ name = "validate_control_flow",
+ srcs = ["validate_control_flow.cc"],
+ hdrs = ["validate_control_flow.h"],
+ deps = [
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "validate_control_flow_test",
+ srcs = ["validate_control_flow_test.cc"],
+ deps = [
+ ":validate_control_flow",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:while_loop",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "functionalize_control_flow",
srcs = ["functionalize_control_flow.cc"],
hdrs = ["functionalize_control_flow.h"],
deps = [
":tf2xla_util",
+ ":validate_control_flow",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 1438f6b48c..b9ed44e354 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/validate_control_flow.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -1439,7 +1440,9 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
// invariant.
std::vector<ControlFlowInfo> cf_info;
std::vector<string> unreachable_nodes;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ BuildAndValidateControlFlowInfo(graph, &cf_info, &unreachable_nodes),
+ "FunctionalizeControlFlow failed");
if (!unreachable_nodes.empty()) {
return errors::InvalidArgument(
"The following nodes are unreachable from the source in the graph: ",
@@ -1464,10 +1467,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
frame.parent = parent;
frame.name = cf.frame_name;
++parent->num_children;
- } else if (frame.parent != parent) {
- return errors::InvalidArgument("Mismatched parent frames for ",
- cf.frame->id(), ": ", parent->name, " vs ",
- frame.parent->name);
}
if (IsEnter(node)) {
@@ -1477,12 +1476,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
&arg.is_loop_invariant));
frame.args.push_back(arg);
} else if (IsLoopCond(node)) {
- if (frame.loop_cond) {
- return errors::InvalidArgument(
- "Loop ", cf.frame_name,
- " has more than one LoopCond node: ", node->name(), " and ",
- frame.loop_cond->name());
- }
frame.loop_cond = node;
}
frame.nodes.insert(node);
diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.cc b/tensorflow/compiler/tf2xla/validate_control_flow.cc
new file mode 100644
index 0000000000..1b3be4cfa4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/validate_control_flow.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/validate_control_flow.h"
+
+#include <vector>
+
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+// Information about a loop frame structure.
+struct Frame {
+ string name;
+
+ // Pointer to the parent frame. The root frame has a pointer to itself.
+ Frame* parent = nullptr;
+
+ // The loop condition of the loop. There should be exactly one loop condition
+ // in every loop.
+ const Node* loop_cond = nullptr;
+};
+
+// Verify that the ControlFlowInfo of the graph has valid loop structure.
+Status ValidateControlFlowInfo(const Graph* graph,
+ const std::vector<ControlFlowInfo>& cf_info) {
+ std::unordered_map<string, Frame> frames;
+ for (const Node* node : graph->op_nodes()) {
+ const ControlFlowInfo& cf = cf_info[node->id()];
+ if (!cf.frame || !cf.parent_frame) {
+ // Skip nodes unreachable from the source node. They might be pruned
+ // later.
+ continue;
+ }
+
+ Frame& frame = frames[cf.frame_name];
+ Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
+ if (frame.parent == nullptr) {
+ frame.parent = parent;
+ frame.name = cf.frame_name;
+ } else if (frame.parent != parent) {
+ return errors::InvalidArgument(
+ "Invalid loop structure: Mismatched parent frames for \"",
+ cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
+ "\". This is an internal bug, please file a bug report with "
+ "instructions on how to reproduce the error.");
+ }
+ if (IsLoopCond(node)) {
+ if (frame.loop_cond) {
+ return errors::InvalidArgument(
+ "Invalid loop structure: Loop \"", cf.frame_name,
+ "\" has more than one LoopCond node: \"", node->name(), "\" and \"",
+ frame.loop_cond->name(),
+ "\". This is an internal bug, please file a bug report with "
+ "instructions on how to reproduce the error.");
+ }
+ frame.loop_cond = node;
+ }
+ }
+ return Status::OK();
+}
+} // namespace
+
+Status BuildAndValidateControlFlowInfo(const Graph* graph,
+ std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes) {
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, info, unreachable_nodes));
+ return ValidateControlFlowInfo(graph, *info);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.h b/tensorflow/compiler/tf2xla/validate_control_flow.h
new file mode 100644
index 0000000000..74159dc929
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/validate_control_flow.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_
+#define TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Populate the control flow frame info of each node in the graph. Verify that
+// the graph has well-formed control flow strcuture that can be functionalized.
+// If unreachable_nodes is not nullptr, append to it the names of nodes
+// unreachable from the source node.
+Status BuildAndValidateControlFlowInfo(
+ const Graph* graph, std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes = nullptr);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_
diff --git a/tensorflow/compiler/tf2xla/validate_control_flow_test.cc b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc
new file mode 100644
index 0000000000..74c9f4b86c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/validate_control_flow.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ops::Less(scope, inputs[0], 10);
+ return scope.status();
+}
+
+Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
+ return scope.status();
+}
+
+Status NestedLoopBody(const Scope& scope, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs,
+ LessThanTenCond, AddOneBody, "inner_loop",
+ outputs);
+}
+
+TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), inputs,
+ LessThanTenCond, NestedLoopBody,
+ "outer_loop", &outputs));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ // {inner/Enter', 'outer/Switch'} --> 'inner/Merge'. 'inner/Enter' is in frame
+ // 'inner_loop'. 'outer/Switch' is in frame 'outer_loop'.
+ std::vector<ControlFlowInfo> info;
+ Status status = BuildAndValidateControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "has inputs from different frames"))
+ << status.error_message();
+}
+
+TEST(ValidateControlFlowTest, MismatchedParentFrames) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody,
+ "test_loop", &outputs));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ Node* enter_1 = nullptr;
+ for (Node* node : graph->op_nodes()) {
+ if (IsEnter(node)) {
+ enter_1 = node;
+ }
+ }
+ ASSERT_TRUE(enter_1 != nullptr);
+
+ NodeDef enter;
+ enter.set_name("Enter2");
+ enter.set_op("Enter");
+ (*enter.mutable_attr())["T"].set_type(DT_INT32);
+ (*enter.mutable_attr())["frame_name"].set_s("test_loop");
+ *enter.add_input() = "Enter";
+ Status status;
+ Node* enter_2 = graph->AddNode(enter, &status);
+ TF_ASSERT_OK(status);
+ graph->AddControlEdge(enter_1, enter_2);
+
+ // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop")
+ // For node 'Enter', the parent frame of "test_loop" is empty.
+ // For node 'Enter2', the parent frame of "test_loop" is "test_loop".
+ std::vector<ControlFlowInfo> info;
+ status = BuildAndValidateControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "Mismatched parent frames"))
+ << status.error_message();
+}
+
+TEST(ValidateControlFlowTest, TwoLoopCond) {
+ // Test that one frame has at most one LoopCond node. This is necessary for
+ // functionalize control flow.
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody,
+ "test_loop", &outputs));
+ outputs.clear();
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("sub"), inputs,
+ LessThanTenCond, AddOneBody, "test_loop",
+ &outputs, false));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ std::vector<ControlFlowInfo> info;
+ Status status = BuildAndValidateControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "more than one LoopCond node"))
+ << status.error_message();
+}
+
+} // namespace
+} // namespace tensorflow