Diffstat (limited to 'tensorflow/compiler/jit/partially_decluster_pass.cc')
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc | 175
1 file changed, 159 insertions, 16 deletions
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 584c963f71..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
   return Status::OK();
 }
 
-} // namespace
-
-Status PartiallyDeclusterPass::Run(
-    const GraphOptimizationPassOptions& options) {
-  // NB! In this pass we assume the only XLA-auto-clusterable operations that
-  // may have side effects are resource variable operations so we don't cluster
-  // those. The pass will have to be updated if this assumption becomes
-  // invalid.
-
-  Graph* graph = options.graph->get();
-
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+//         .....
+//           |
+//           v
+//      A_Clustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// to:
+//
+//         .....
+//          |   |
+//          |   +-------------+
+//          |                 |
+//          v                 v
+//      A_Clustered   A_Unclustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
   // When deciding whether to decluster a particular node, we base our decision
   // on if we've decided that some of its consumers have to be declustered too.
   // Iterating the graph in post-order guarantees that consumers have been
   // visited before producers.
   std::vector<Node*> post_order;
   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
-               /*edge_filter=*/[](const Edge& edge) {
-                 return !edge.src()->IsNextIteration();
-               });
+               /*edge_filter=*/NotBackedge);
 
   gtl::FlatSet<Node*> nodes_to_partially_decluster;
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
 
   if (VLOG_IS_ON(3)) {
     for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
   }
 
   nodes_to_partially_decluster.clear();
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
   CHECK(nodes_to_partially_decluster.empty());
 
   return Status::OK();
 }
+
+bool IsIntraClusterEdge(const Edge& edge) {
+  absl::optional<absl::string_view> src_cluster_name =
+      GetXlaClusterForNode(*edge.src());
+  absl::optional<absl::string_view> dst_cluster_name =
+      GetXlaClusterForNode(*edge.dst());
+  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+  DeviceType device_type("");
+  TF_RETURN_IF_ERROR(
+      DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
+                                           &registration)) {
+    *result = false;
+  } else {
+    *result = registration->requires_compilation;
+  }
+
+  return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+//   x0 = arg0
+//   x1 = arg1
+//   ...
+//   shape = f(x0, x1, ...)
+//   result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations
+// and will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many-to-one. For instance,
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function", i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+      *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+  std::vector<Node*> rpo;
+  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+                      /*edge_filter=*/NotBackedge);
+  for (Node* n : rpo) {
+    if (!compile_time_const_nodes[n->id()]) {
+      continue;
+    }
+
+    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+    bool node_on_cluster_edge =
+        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+          absl::optional<absl::string_view> incoming_cluster =
+              GetXlaClusterForNode(*e->src());
+          return !incoming_cluster || *incoming_cluster != cluster_name;
+        });
+
+    // We don't want to decluster F in a graph like
+    //
+    //   Input -> OP -> Shape -> F -> Reshape
+    //
+    // Doing so will break up the cluster. Even if we were okay with breaking
+    // up the cluster we will at least have to relabel the two clusters to have
+    // different cluster names.
+    //
+    // We may want to revisit this in the future: we may have cases where OP is
+    // a small computation that does not benefit from XLA while XLA can
+    // optimize everything that follows the Reshape. In these cases it may be
+    // wise to remove Input, OP, Shape and F from the cluster, if F is a
+    // many-to-one function.
+    //
+    // Note that we do the right thing for graphs like:
+    //
+    //   Input -> F0 -> F1 -> Reshape
+    //
+    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+    // encounter F1, decluster it and so on.
+    if (node_on_cluster_edge) {
+      bool must_compile_node;
+      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+      if (!must_compile_node) {
+        VLOG(3) << "Declustering must-be-constant node " << n->name();
+        RemoveFromXlaCluster(n);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  // NB! In this pass we assume the only XLA-auto-clusterable operations that
+  // may have side effects are resource variable operations so we don't cluster
+  // those. The pass will have to be updated if this assumption becomes
+  // invalid.
+
+  Graph* graph = options.graph->get();
+
+  TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+  TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+  return Status::OK();
+}
 } // namespace tensorflow
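
Aside: the many-to-one argument in the DeclusterNodesToReduceRecompilations comment can be checked with a small standalone program. The sketch below is not TensorFlow code and is not part of the patch; it is a minimal C++17 model of an XLA-style compilation cache (the names runs, keyed_by_args, and keyed_by_result are invented for illustration), keyed either on the raw must-be-constant inputs (f clustered) or on f's host-computed result (f declustered):

// Standalone illustration only -- not part of the patch above. Models a
// compilation cache for two executions whose must-be-constant inputs differ
// but whose shape function f(x, y) = x - y produces the same value.
#include <iostream>
#include <set>
#include <utility>

int main() {
  const std::pair<int, int> runs[] = {{3, 1}, {4, 2}};

  std::set<std::pair<int, int>> keyed_by_args;  // f inside the cluster
  std::set<int> keyed_by_result;                // f declustered to the host

  for (const auto& [x, y] : runs) {
    keyed_by_args.insert({x, y});   // cache key: the raw constant inputs
    keyed_by_result.insert(x - y);  // cache key: f's value, computed on host
  }

  // Prints 2 vs. 1: declustering the many-to-one f halves the compilations.
  std::cout << "f clustered:   " << keyed_by_args.size() << " compilations\n"
            << "f declustered: " << keyed_by_result.size() << " compilations\n";
}

Conversely, if f were one-to-many (e.g. a stateful RNG, so equal inputs could yield different outputs across runs), keyed_by_result could grow larger than keyed_by_args; that is exactly the case the comment rules out for must-be-constant nodes.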