author    Sanjoy Das <sanjoy@google.com>                  2018-09-07 18:47:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-07 18:53:18 -0700
commit    4fd48f57cd1dcd960bea1757e1c59032db66b3d0 (patch)
tree      49ef65ef479a2904ff08d5ed9bfb65bac08a64f4 /tensorflow/compiler/jit/partially_decluster_pass.cc
parent    3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff)
Decluster some must-be-constant ops to reduce XLA recompilations
The CL is organized as follows:

- The main change is in jit/partially_decluster_pass.
- tf2xla/const_analysis now takes an "edge_filter" to facilitate use by
  jit/partially_decluster_pass.
- tests/dense_layer_test.py was using the execution of ListDiff as what I
  assume is a sanity check to see that the XLA cluster ran. With this CL the
  ListDiff op gets declustered so we now check for "MatMult" for the sanity
  check.
- Some tests were dropping TF_XLA_FLAGS; fixed them to not do so.

PiperOrigin-RevId: 212071118
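The edge_filter mentioned above is just a predicate over graph edges; in the diff below it is used to keep the backwards const analysis from following edges that cross a cluster boundary (IsIntraClusterEdge) and to skip frame back-edges during graph traversal (NotBackedge). The following is a rough, self-contained sketch of that idea only, using a hypothetical Edge type and helper rather than the tf2xla API:

// Minimal sketch, not TensorFlow code: a backwards "must be constant"
// propagation that only follows edges accepted by `edge_filter`.
#include <functional>
#include <vector>

struct Edge {
  int src;                        // producer node id
  int dst;                        // consumer node id
  bool crosses_cluster_boundary;  // stands in for GetXlaClusterForNode checks
};

std::vector<bool> BackwardsConstSketch(
    const std::vector<Edge>& edges, std::vector<bool> must_be_constant,
    const std::function<bool(const Edge&)>& edge_filter) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (const Edge& e : edges) {
      if (edge_filter(e) && must_be_constant[e.dst] &&
          !must_be_constant[e.src]) {
        must_be_constant[e.src] = true;  // the producer must be constant too
        changed = true;
      }
    }
  }
  return must_be_constant;
}

// Usage: pass a predicate analogous to IsIntraClusterEdge so that a
// must-be-constant requirement never propagates out of an XLA cluster:
//   auto result = BackwardsConstSketch(
//       edges, must_be_constant,
//       [](const Edge& e) { return !e.crosses_cluster_boundary; });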
Diffstat (limited to 'tensorflow/compiler/jit/partially_decluster_pass.cc')
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc  175
1 file changed, 159 insertions, 16 deletions
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 584c963f71..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
-} // namespace
-Status PartiallyDeclusterPass::Run(
- const GraphOptimizationPassOptions& options) {
- // NB! In this pass we assume the only XLA-auto-clusterable operations that
- // may have side effects are resource variable operations so we don't cluster
- // those. The pass will have to be updated if this assumption becomes
- // invalid.
-
- Graph* graph = options.graph->get();
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+//         .....
+//           |
+//           v
+//      A_Clustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// to:
+//
+//         .....
+//          |   |
+//          |   +-------------+
+//          |                 |
+//          v                 v
+//      A_Clustered     A_Unclustered ====> C_Unclustered
+//          |
+//          v
+//      B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
- /*edge_filter=*/[](const Edge& edge) {
- return !edge.src()->IsNextIteration();
- });
+ /*edge_filter=*/NotBackedge);
gtl::FlatSet<Node*> nodes_to_partially_decluster;
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
}
nodes_to_partially_decluster.clear();
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
+
+bool IsIntraClusterEdge(const Edge& edge) {
+ absl::optional<absl::string_view> src_cluster_name =
+ GetXlaClusterForNode(*edge.src());
+ absl::optional<absl::string_view> dst_cluster_name =
+ GetXlaClusterForNode(*edge.dst());
+ return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *result = false;
+ } else {
+ *result = registration->requires_compilation;
+ }
+
+ return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+// x0 = arg0
+// x1 = arg1
+// ...
+// shape = f(x0, x1, ...)
+// result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many-to-one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+ std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/NotBackedge);
+ for (Node* n : rpo) {
+ if (!compile_time_const_nodes[n->id()]) {
+ continue;
+ }
+
+ absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+ bool node_on_cluster_edge =
+ absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+ absl::optional<absl::string_view> incoming_cluster =
+ GetXlaClusterForNode(*e->src());
+ return !incoming_cluster || *incoming_cluster != cluster_name;
+ });
+
+ // We don't want to decluster F in a graph like
+ //
+ // Input -> OP -> Shape -> F -> Reshape
+ //
+ // Doing so will break up the cluster. Even if we were okay with breaking
+ // up the cluster we will at least have to relabel the two clusters to have
+ // different cluster names.
+ //
+ // We may want to revisit this in the future: we may have cases where OP is
+ // a small computation that does not benefit from XLA while XLA can optimize
+ // everything that follows the Reshape. In these cases it may be wise to
+ // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+ // function.
+ //
+ // Note that we do do the right thing for graphs like:
+ //
+ // Input -> F0 -> F1 -> Reshape
+ //
+ // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+ // encounter F1, decluster it and so on.
+ if (node_on_cluster_edge) {
+ bool must_compile_node;
+ TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+ if (!must_compile_node) {
+ VLOG(3) << "Declustering must-be-constant node " << n->name();
+ RemoveFromXlaCluster(n);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable operations so we don't cluster
+ // those. The pass will have to be updated if this assumption becomes
+ // invalid.
+
+ Graph* graph = options.graph->get();
+
+ TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+ TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+ return Status::OK();
+}
} // namespace tensorflow
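As a side note on the recompilation argument in the comment above (the f(x, y) = x - y example with x=3, y=1 and x=4, y=2): the toy snippet below uses no TensorFlow APIs; it just counts distinct "compilation signatures" to illustrate why a many-to-one f triggers fewer compilations once it is declustered and only its result feeds the cluster.

// Toy illustration only; not TensorFlow code.
#include <iostream>
#include <set>
#include <utility>

int main() {
  // With f inside the cluster, the cluster's signature includes (x, y);
  // with f declustered, only the constant shape f(x, y) = x - y is seen.
  std::set<std::pair<int, int>> clustered_signatures;
  std::set<int> declustered_signatures;
  const std::pair<int, int> inputs[] = {{3, 1}, {4, 2}};
  for (const auto& input : inputs) {
    clustered_signatures.insert(input);                          // (3,1), (4,2)
    declustered_signatures.insert(input.first - input.second);   // 2, 2
  }
  std::cout << "compilations with f clustered:   "
            << clustered_signatures.size() << "\n"    // prints 2
            << "compilations with f declustered: "
            << declustered_signatures.size() << "\n"; // prints 1
  return 0;
}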