author     Sanjoy Das <sanjoy@google.com>                    2018-09-07 18:47:56 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-07 18:53:18 -0700
commit     4fd48f57cd1dcd960bea1757e1c59032db66b3d0 (patch)
tree       49ef65ef479a2904ff08d5ed9bfb65bac08a64f4 /tensorflow/compiler/jit
parent     3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff)
Decluster some must-be-constant ops to reduce XLA recompilations
The CL is organized as follows:
- The main change is in jit/partially_decluster_pass.
- tf2xla/const_analysis now takes an "edge_filter" to facilitate use by
jit/partially_decluster_pass.
- tests/dense_layer_test.py was using the execution of ListDiff as (what I
  assume is) a sanity check that the XLA cluster ran. With this CL the
  ListDiff op gets declustered, so we now check for "MatMult" instead.
- Some tests were dropping TF_XLA_FLAGS; fixed them to not do so.
PiperOrigin-RevId: 212071118
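To make the recompilation argument concrete: XLA's compilation cache is keyed on the values of a cluster's must-be-constant inputs. If a many-to-one function f (e.g. the shape computation feeding a Reshape) stays inside the cluster, the cache keys on f's inputs; if f is declustered, the cache keys on f's output, which can only collapse keys together. A toy illustration of that counting argument (plain C++ with hypothetical values, not TensorFlow code):

```cpp
#include <cstdio>
#include <set>
#include <utility>
#include <vector>

int main() {
  // f(x, y) = x - y is many-to-one: distinct input pairs can produce the
  // same shape value.
  auto f = [](int x, int y) { return x - y; };

  // Hypothetical shape operands observed across several executions.
  std::vector<std::pair<int, int>> runs = {{3, 1}, {4, 2}, {5, 3}, {7, 1}};

  std::set<std::pair<int, int>> clustered_keys;  // f inside: key = (x, y)
  std::set<int> declustered_keys;                // f outside: key = f(x, y)
  for (const auto& [x, y] : runs) {
    clustered_keys.insert({x, y});
    declustered_keys.insert(f(x, y));
  }

  // Prints "clustered: 4, declustered: 2" -- declustering f halves the
  // number of distinct compilation-cache keys for these inputs.
  std::printf("clustered: %zu, declustered: %zu\n", clustered_keys.size(),
              declustered_keys.size());
}
```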
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/BUILD                            |   4
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc      | 175
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.h       |  31
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc | 133
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc              |   2
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h               |   3
6 files changed, 302 insertions, 46 deletions
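The const_analysis change itself lives outside this directory, but the `edge_filter` idea mentioned in the commit message is simple enough to sketch standalone: must-be-constant-ness propagates backwards from consumers to producers, and the filter (here, "is this edge inside one cluster?") stops the propagation at cluster boundaries. A toy fixed-point version with made-up graph types, not the actual tf2xla API:

```cpp
#include <functional>
#include <vector>

struct Edge {
  int src, dst;
  bool intra_cluster;  // stands in for IsIntraClusterEdge in the real pass
};

// Marks `seed` (e.g. Reshape's shape operand) and every producer reachable
// backwards through edges accepted by `edge_filter`.
std::vector<bool> BackwardsConstAnalysisSketch(
    int num_nodes, const std::vector<Edge>& edges, int seed,
    const std::function<bool(const Edge&)>& edge_filter) {
  std::vector<bool> must_be_constant(num_nodes, false);
  must_be_constant[seed] = true;
  bool changed = true;
  while (changed) {  // iterate to a fixed point
    changed = false;
    for (const Edge& e : edges) {
      if (edge_filter(e) && must_be_constant[e.dst] &&
          !must_be_constant[e.src]) {
        must_be_constant[e.src] = true;
        changed = true;
      }
    }
  }
  return must_be_constant;
}
```

With the filter set to accept only intra-cluster edges, a constant requirement inside one cluster never marks producers that live outside it, which is what `DeclusterNodesToReduceRecompilations` in the diff below relies on.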
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index de7cd26d1d..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -395,6 +395,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -480,6 +481,7 @@ tf_cc_test(
         ":common",
         ":compilation_passes",
         ":xla_cluster_util",
+        ":xla_gpu_device",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
@@ -496,6 +498,8 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler/optimizers/data:graph_utils",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 584c963f71..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
   return Status::OK();
 }
-}  // namespace
 
-Status PartiallyDeclusterPass::Run(
-    const GraphOptimizationPassOptions& options) {
-  // NB! In this pass we assume the only XLA-auto-clusterable operations that
-  // may have side effects are resource variable operations so we don't cluster
-  // those. The pass will have to be updated if this assumption becomes
-  // invalid.
-
-  Graph* graph = options.graph->get();
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
 
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+//         .....
+//           |
+//           v
+//      A_Clustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// to:
+//
+//         .....
+//          |   |
+//          |   +-------------+
+//          |                 |
+//          v                 v
+//      A_Clustered     A_Unclustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
   // When deciding whether to decluster a particular node, we base our decision
   // on if we've decided that some of its consumers have to be declustered too.
   // Iterating the graph in post-order guarantees that consumers have been
   // visited before producers.
   std::vector<Node*> post_order;
   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
-               /*edge_filter=*/[](const Edge& edge) {
-                 return !edge.src()->IsNextIteration();
-               });
+               /*edge_filter=*/NotBackedge);
 
   gtl::FlatSet<Node*> nodes_to_partially_decluster;
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
 
   if (VLOG_IS_ON(3)) {
     for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
   }
 
   nodes_to_partially_decluster.clear();
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
   CHECK(nodes_to_partially_decluster.empty());
 
   return Status::OK();
 }
+
+bool IsIntraClusterEdge(const Edge& edge) {
+  absl::optional<absl::string_view> src_cluster_name =
+      GetXlaClusterForNode(*edge.src());
+  absl::optional<absl::string_view> dst_cluster_name =
+      GetXlaClusterForNode(*edge.dst());
+  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+  DeviceType device_type("");
+  TF_RETURN_IF_ERROR(
+      DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
+                                           &registration)) {
+    *result = false;
+  } else {
+    *result = registration->requires_compilation;
+  }
+
+  return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+//   x0 = arg0
+//   x1 = arg1
+//     ...
+//   shape = f(x0, x1, ...)
+//   result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations
+// and will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many-to-one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function", i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+      *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+  std::vector<Node*> rpo;
+  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+                      /*edge_filter=*/NotBackedge);
+  for (Node* n : rpo) {
+    if (!compile_time_const_nodes[n->id()]) {
+      continue;
+    }
+
+    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+    bool node_on_cluster_edge =
+        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+          absl::optional<absl::string_view> incoming_cluster =
+              GetXlaClusterForNode(*e->src());
+          return !incoming_cluster || *incoming_cluster != cluster_name;
+        });
+
+    // We don't want to decluster F in a graph like
+    //
+    //   Input -> OP -> Shape -> F -> Reshape
+    //
+    // Doing so will break up the cluster. Even if we were okay with breaking
+    // up the cluster we will at least have to relabel the two clusters to have
+    // different cluster names.
+    //
+    // We may want to revisit this in the future: we may have cases where OP is
+    // a small computation that does not benefit from XLA while XLA can
+    // optimize everything that follows the Reshape. In these cases it may be
+    // wise to remove Input, OP, Shape and F from the cluster, if F is a
+    // many-to-one function.
+    //
+    // Note that we do the right thing for graphs like:
+    //
+    //   Input -> F0 -> F1 -> Reshape
+    //
+    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+    // encounter F1, decluster it and so on.
+    if (node_on_cluster_edge) {
+      bool must_compile_node;
+      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+      if (!must_compile_node) {
+        VLOG(3) << "Declustering must-be-constant node " << n->name();
+        RemoveFromXlaCluster(n);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status PartiallyDeclusterPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  // NB! In this pass we assume the only XLA-auto-clusterable operations that
+  // may have side effects are resource variable operations so we don't cluster
+  // those. The pass will have to be updated if this assumption becomes
+  // invalid.
+
+  Graph* graph = options.graph->get();
+
+  TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+  TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+  return Status::OK();
+}
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
index 6949b5028e..cfc4ddb563 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.h
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable. There are two reasons why we do this:
 //
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone. For instance, we convert this:
-//
-//         .....
-//           |
-//           v
-//      A_Clustered ====> C_Unclustered
-//           |
-//           v
-//      B_Clustered
-//
-// to:
-//
-//         .....
-//          |   |
-//          |   +-------------+
-//          |                 |
-//          v                 v
-//      A_Clustered     A_Unclustered ====> C_Unclustered
-//           |
-//           v
-//      B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+// - Reducing device-to-host copies.
+// - Reducing the number of XLA recompilations.
 class PartiallyDeclusterPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index f61a955c22..35872daa65 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -82,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
   // Assign all nodes to the CPU device.
   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
   for (Node* n : (*graph)->nodes()) {
-    n->set_assigned_device_name(kCpuDevice);
+    if (n->assigned_device_name().empty()) {
+      n->set_assigned_device_name(kCpuDevice);
+    }
   }
 
   GraphOptimizationPassOptions opt_options;
@@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
   return pass.Run(opt_options);
 }
 
-const Node* FindNodeByName(const Graph& graph, const string& name) {
-  for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+  for (Node* node : graph.nodes()) {
     if (node->name() == name) {
       return node;
     }
@@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
                "ClusteredProducer0/declustered");
   EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
 }
+
+void AddToCluster(absl::Span<Node* const> nodes,
+                  absl::string_view cluster_name) {
+  for (Node* n : nodes) {
+    n->AddAttr(kXlaClusterAttr, string(cluster_name));
+  }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+                                    ops::Placeholder::Attrs{});
+  Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+  Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+  Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+               "cluster_0");
+
+  std::unique_ptr<Graph> graph =
+      absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({reshape.node()}, "cluster_0");
+  AddToCluster({shape.node()}, "cluster_1");
+
+  std::unique_ptr<Graph> graph =
+      absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  std::unique_ptr<Graph> graph =
+      absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+  // This is needed to register the XLA_GPU device.
+  std::vector<Device*> devices;
+  TF_ASSERT_OK(DeviceFactory::AddDevices(
+      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+  // Scope::ToGraph loses the assigned device name since it goes through
+  // GraphDef/NodeDef which does not have a field for the assigned device name.
+  Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+  n->set_assigned_device_name(
+      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+  for (Device* d : devices) {
+    delete d;
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 03380e9406..f85121ca27 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
   node_def->mutable_attr()->erase(kXlaClusterAttr);
 }
 
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
 Status AdjustCycleDetectionGraphForResourceOps(
     const Graph* graph, const FunctionLibraryDefinition* flib_def,
     const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index debd9038c7..94c96ac7c5 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
 
 // Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
 void RemoveFromXlaCluster(NodeDef* node_def);
 
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
 // Returns true if `node` has a DT_RESOURCE typed input or output.
 bool HasResourceInputOrOutput(const Node& node);
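Usage note on the new overload: `RemoveFromXlaCluster(Node*)` pairs naturally with `GetXlaClusterForNode` for in-place graph edits during a pass, as `DeclusterNodesToReduceRecompilations` does above. A minimal sketch against the headers in this commit (the helper name is made up; error handling elided):

```cpp
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Hypothetical helper: drops every node of cluster `name` out of its XLA
// cluster by clearing the _XlaCluster attribute.
void DeclusterAllNodesIn(Graph* graph, absl::string_view name) {
  for (Node* n : graph->nodes()) {
    absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
    if (cluster.has_value() && *cluster == name) {
      RemoveFromXlaCluster(n);
    }
  }
}

}  // namespace tensorflow
```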