From c9a2034f93981e17eef5f96fbd2894202b8fc2c1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 15 Jun 2018 10:25:09 -0700 Subject: [TF:XLA] Validate the control flow structure in encapsulate_subgraphs_pass and encapsulate_tpu_computations_pass, in order to detect errors earlier. PiperOrigin-RevId: 200735435 --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/encapsulate_subgraphs_pass.cc | 16 ++- tensorflow/compiler/tf2xla/BUILD | 27 +++++ .../compiler/tf2xla/functionalize_control_flow.cc | 15 +-- .../compiler/tf2xla/validate_control_flow.cc | 84 +++++++++++++ tensorflow/compiler/tf2xla/validate_control_flow.h | 37 ++++++ .../compiler/tf2xla/validate_control_flow_test.cc | 131 +++++++++++++++++++++ 7 files changed, 296 insertions(+), 15 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/validate_control_flow.cc create mode 100644 tensorflow/compiler/tf2xla/validate_control_flow.h create mode 100644 tensorflow/compiler/tf2xla/validate_control_flow_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 8c74014614..a92218b129 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -321,6 +321,7 @@ cc_library( "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:validate_control_flow", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9448b8ebde..b78c30c215 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -1504,6 +1505,11 @@ Status Encapsulator::SplitIntoSubgraphs() { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); + // Verify that the graph has well-formed control flow structure to be + // functionalized. + std::vector dummy; + TF_RETURN_IF_ERROR( + BuildAndValidateControlFlowInfo(subgraph.GetGraph(), &dummy)); } return s; @@ -2519,10 +2525,12 @@ Status EncapsulateSubgraphsPass::Run( return Status::OK(); }; - TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, - /*reuse_existing_functions=*/false, &graph_out, library)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, + library), + "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302..6b73cee2a8 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -406,12 +406,39 @@ cc_library( ], ) +cc_library( + name = "validate_control_flow", + srcs = ["validate_control_flow.cc"], + hdrs = ["validate_control_flow.h"], + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "validate_control_flow_test", + srcs = ["validate_control_flow_test.cc"], + deps = [ + ":validate_control_flow", + 
"//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:while_loop", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "functionalize_control_flow", srcs = ["functionalize_control_flow.cc"], hdrs = ["functionalize_control_flow.h"], deps = [ ":tf2xla_util", + ":validate_control_flow", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1438f6b48c..b9ed44e354 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -1439,7 +1440,9 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // invariant. 
std::vector cf_info; std::vector unreachable_nodes; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + BuildAndValidateControlFlowInfo(graph, &cf_info, &unreachable_nodes), + "FunctionalizeControlFlow failed"); if (!unreachable_nodes.empty()) { return errors::InvalidArgument( "The following nodes are unreachable from the source in the graph: ", @@ -1464,10 +1467,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, frame.parent = parent; frame.name = cf.frame_name; ++parent->num_children; - } else if (frame.parent != parent) { - return errors::InvalidArgument("Mismatched parent frames for ", - cf.frame->id(), ": ", parent->name, " vs ", - frame.parent->name); } if (IsEnter(node)) { @@ -1477,12 +1476,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, &arg.is_loop_invariant)); frame.args.push_back(arg); } else if (IsLoopCond(node)) { - if (frame.loop_cond) { - return errors::InvalidArgument( - "Loop ", cf.frame_name, - " has more than one LoopCond node: ", node->name(), " and ", - frame.loop_cond->name()); - } frame.loop_cond = node; } frame.nodes.insert(node); diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.cc b/tensorflow/compiler/tf2xla/validate_control_flow.cc new file mode 100644 index 0000000000..1b3be4cfa4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" + +#include + +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { +// Information about a loop frame structure. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + const Node* loop_cond = nullptr; +}; + +// Verify that the ControlFlowInfo of the graph has valid loop structure. +Status ValidateControlFlowInfo(const Graph* graph, + const std::vector& cf_info) { + std::unordered_map frames; + for (const Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + if (!cf.frame || !cf.parent_frame) { + // Skip nodes unreachable from the source node. They might be pruned + // later. + continue; + } + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + } else if (frame.parent != parent) { + return errors::InvalidArgument( + "Invalid loop structure: Mismatched parent frames for \"", + cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name, + "\". This is an internal bug, please file a bug report with " + "instructions on how to reproduce the error."); + } + if (IsLoopCond(node)) { + if (frame.loop_cond) { + return errors::InvalidArgument( + "Invalid loop structure: Loop \"", cf.frame_name, + "\" has more than one LoopCond node: \"", node->name(), "\" and \"", + frame.loop_cond->name(), + "\". 
This is an internal bug, please file a bug report with " + "instructions on how to reproduce the error."); + } + frame.loop_cond = node; + } + } + return Status::OK(); +} +} // namespace + +Status BuildAndValidateControlFlowInfo(const Graph* graph, + std::vector* info, + std::vector* unreachable_nodes) { + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, info, unreachable_nodes)); + return ValidateControlFlowInfo(graph, *info); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.h b/tensorflow/compiler/tf2xla/validate_control_flow.h new file mode 100644 index 0000000000..74159dc929 --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ + +#include + +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Populate the control flow frame info of each node in the graph. Verify that +// the graph has well-formed control flow structure that can be functionalized. +// If unreachable_nodes is not nullptr, append to it the names of nodes +// unreachable from the source node. 
+Status BuildAndValidateControlFlowInfo( + const Graph* graph, std::vector* info, + std::vector* unreachable_nodes = nullptr); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/validate_control_flow_test.cc b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc new file mode 100644 index 0000000000..74c9f4b86c --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" + +#include +#include + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +Status LessThanTenCond(const Scope& scope, const std::vector& inputs, + Output* output) { + *output = ops::Less(scope, inputs[0], 10); + return scope.status(); +} + +Status AddOneBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { + outputs->push_back(ops::AddN(scope, {inputs[0], 1})); + return scope.status(); +} + +Status NestedLoopBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { + return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs, + LessThanTenCond, AddOneBody, "inner_loop", + outputs); +} + +TEST(ValidateControlFlowTest, InputsFromDifferentFrames) { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), inputs, + LessThanTenCond, NestedLoopBody, + "outer_loop", &outputs)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + // {'inner/Enter', 'outer/Switch'} --> 'inner/Merge'. 'inner/Enter' is in frame + // 'inner_loop'. 'outer/Switch' is in frame 'outer_loop'. 
+ std::vector info; + Status status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "has inputs from different frames")) + << status.error_message(); +} + +TEST(ValidateControlFlowTest, MismatchedParentFrames) { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody, + "test_loop", &outputs)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + Node* enter_1 = nullptr; + for (Node* node : graph->op_nodes()) { + if (IsEnter(node)) { + enter_1 = node; + } + } + ASSERT_TRUE(enter_1 != nullptr); + + NodeDef enter; + enter.set_name("Enter2"); + enter.set_op("Enter"); + (*enter.mutable_attr())["T"].set_type(DT_INT32); + (*enter.mutable_attr())["frame_name"].set_s("test_loop"); + *enter.add_input() = "Enter"; + Status status; + Node* enter_2 = graph->AddNode(enter, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(enter_1, enter_2); + + // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop") + // For node 'Enter', the parent frame of "test_loop" is empty. + // For node 'Enter2', the parent frame of "test_loop" is "test_loop". + std::vector info; + status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "Mismatched parent frames")) + << status.error_message(); +} + +TEST(ValidateControlFlowTest, TwoLoopCond) { + // Test that one frame has at most one LoopCond node. This is necessary for + // functionalize control flow. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody, + "test_loop", &outputs)); + outputs.clear(); + TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("sub"), inputs, + LessThanTenCond, AddOneBody, "test_loop", + &outputs, false)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + std::vector info; + Status status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "more than one LoopCond node")) + << status.error_message(); +} + +} // namespace +} // namespace tensorflow -- cgit v1.2.3