path: root/tensorflow/compiler/jit
author    Sanjoy Das <sanjoy@google.com>  2018-09-07 18:47:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-07 18:53:18 -0700
commit    4fd48f57cd1dcd960bea1757e1c59032db66b3d0
tree      49ef65ef479a2904ff08d5ed9bfb65bac08a64f4  /tensorflow/compiler/jit
parent    3e1b06ee93d7a638db1fdd5f733d66064c1acf59
Decluster some must-be-constant ops to reduce XLA recompilations
The CL is organized as follows:

- The main change is in jit/partially_decluster_pass.
- tf2xla/const_analysis now takes an "edge_filter" to facilitate use by jit/partially_decluster_pass.
- tests/dense_layer_test.py was using the execution of ListDiff as what I assume is a sanity check to see that the XLA cluster ran. With this CL the ListDiff op gets declustered so we now check for "MatMult" for the sanity check.
- Some tests were dropping TF_XLA_FLAGS; fixed them to not do so.

PiperOrigin-RevId: 212071118
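For orientation, a minimal sketch of the graph shape this change targets (it condenses the DeclusterMustBeConstantNodes test added below; the helper name is made up): a must-be-constant producer, `shape`, clustered together with the `Reshape` that consumes it. While `shape` sits in the cluster, `shape_a` and `shape_b` are must-be-constant cluster inputs and every distinct pair of values forces a recompile; once `shape` is declustered, only distinct `shape` values do.

#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Builds shape = Add(shape_a, shape_b) feeding Reshape(reshape_input, shape),
// with both `shape` and `reshape` tagged as members of "cluster_0". `shape`
// is a must-be-constant input of the Reshape, which is exactly the situation
// the new pass declusters.
std::unique_ptr<Graph> BuildMustBeConstantExample() {
  Scope s = Scope::NewRootScope();
  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32);
  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32);
  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
  Output input = ops::Placeholder(s.WithOpName("reshape_input"), DT_FLOAT);
  Output reshape = ops::Reshape(s.WithOpName("reshape"), input, shape);

  // Tag the nodes with the _XlaCluster attribute, as the new test helper does.
  for (Node* n : {shape.node(), reshape.node()}) {
    n->AddAttr(kXlaClusterAttr, string("cluster_0"));
  }

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_CHECK_OK(s.ToGraph(graph.get()));
  return graph;
}

}  // namespace tensorflow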
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/BUILD                              4
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc      175
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.h        31
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc 133
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc                2
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h                 3
6 files changed, 302 insertions, 46 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index de7cd26d1d..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -395,6 +395,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
@@ -480,6 +481,7 @@ tf_cc_test(
":common",
":compilation_passes",
":xla_cluster_util",
+ ":xla_gpu_device",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
@@ -496,6 +498,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 584c963f71..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
-} // namespace
-Status PartiallyDeclusterPass::Run(
- const GraphOptimizationPassOptions& options) {
- // NB! In this pass we assume the only XLA-auto-clusterable operations that
- // may have side effects are resource variable operations so we don't cluster
- // those. The pass will have to be updated if this assumption becomes
- // invalid.
-
- Graph* graph = options.graph->get();
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+// .....
+// |
+// v
+// A_Clustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// to:
+//
+// .....
+// | |
+// | +-------------+
+// | |
+// v v
+// A_Clustered A_Unclustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
- /*edge_filter=*/[](const Edge& edge) {
- return !edge.src()->IsNextIteration();
- });
+ /*edge_filter=*/NotBackedge);
gtl::FlatSet<Node*> nodes_to_partially_decluster;
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
}
nodes_to_partially_decluster.clear();
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
+
+bool IsIntraClusterEdge(const Edge& edge) {
+ absl::optional<absl::string_view> src_cluster_name =
+ GetXlaClusterForNode(*edge.src());
+ absl::optional<absl::string_view> dst_cluster_name =
+ GetXlaClusterForNode(*edge.dst());
+ return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *result = false;
+ } else {
+ *result = registration->requires_compilation;
+ }
+
+ return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+// x0 = arg0
+// x1 = arg1
+// ...
+// shape = f(x0, x1, ...)
+// result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many to one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+ std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/NotBackedge);
+ for (Node* n : rpo) {
+ if (!compile_time_const_nodes[n->id()]) {
+ continue;
+ }
+
+ absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+ bool node_on_cluster_edge =
+ absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+ absl::optional<absl::string_view> incoming_cluster =
+ GetXlaClusterForNode(*e->src());
+ return !incoming_cluster || *incoming_cluster != cluster_name;
+ });
+
+ // We don't want to decluster F in a graph like
+ //
+ // Input -> OP -> Shape -> F -> Reshape
+ //
+ // Doing so will break up the cluster. Even if we were okay with breaking
+ // up the cluster we will at least have to relabel the two clusters to have
+ // different cluster names.
+ //
+ // We may want to revisit this in the future: we may have cases where OP is
+ // a small computation that does not benefit from XLA while XLA can optimize
+ // everything that follows the Reshape. In these cases it may be wise to
+ // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+ // function.
+ //
+ // Note that we do do the right thing for graphs like:
+ //
+ // Input -> F0 -> F1 -> Reshape
+ //
+ // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+ // encounter F1, decluster it and so on.
+ if (node_on_cluster_edge) {
+ bool must_compile_node;
+ TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+ if (!must_compile_node) {
+ VLOG(3) << "Declustering must-be-constant node " << n->name();
+ RemoveFromXlaCluster(n);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable operations so we don't cluster
+ // those. The pass will have to be updated if this assumption becomes
+ // invalid.
+
+ Graph* graph = options.graph->get();
+
+ TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+ TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+ return Status::OK();
+}
} // namespace tensorflow
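To run the new pass end to end outside the normal pipeline, a small driver along the lines of the PartiallyDecluster helper in the updated test below can be used (the function name here is illustrative; the device string follows the test). Nodes need an assigned device because MustCompileNode reads it.

#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Runs PartiallyDeclusterPass over `graph` in isolation. Nodes without an
// assigned device get a CPU device first, since MustCompileNode looks at the
// assigned device name.
Status RunPartialDecluster(std::unique_ptr<Graph>* graph) {
  static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
  for (Node* n : (*graph)->nodes()) {
    if (n->assigned_device_name().empty()) {
      n->set_assigned_device_name(kCpuDevice);
    }
  }

  GraphOptimizationPassOptions opt_options;
  opt_options.graph = graph;
  PartiallyDeclusterPass pass;
  return pass.Run(opt_options);
}

}  // namespace tensorflow

Applied to the sketch given after the commit message, GetXlaClusterForNode on the `shape` node should afterwards return absl::nullopt while the `reshape` node keeps "cluster_0", which is what the DeclusterMustBeConstantNodes test asserts.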
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
index 6949b5028e..cfc4ddb563 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.h
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@ limitations under the License.
namespace tensorflow {
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable. There are two reasons why we do this:
//
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone. For instance, we convert this:
-//
-// .....
-// |
-// v
-// A_Clustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// to:
-//
-// .....
-// | |
-// | +-------------+
-// | |
-// v v
-// A_Clustered A_Unclustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+// - Reducing device-to-host copies.
+// - Reducing the number of XLA recompilations.
class PartiallyDeclusterPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
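For context only: a GraphOptimizationPass such as this one is wired into the pipeline through the optimization pass registry. The grouping and phase below are illustrative placeholders, not values taken from this CL.

#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Illustrative only: the actual grouping/phase for PartiallyDeclusterPass are
// defined in the jit pass-registration translation unit, not in this CL.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      PartiallyDeclusterPass);

}  // namespace tensorflow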
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index f61a955c22..35872daa65 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -82,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
GraphOptimizationPassOptions opt_options;
@@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
return pass.Run(opt_options);
}
-const Node* FindNodeByName(const Graph& graph, const string& name) {
- for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+ for (Node* node : graph.nodes()) {
if (node->name() == name) {
return node;
}
@@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
"ClusteredProducer0/declustered");
EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
+
+void AddToCluster(absl::Span<Node* const> nodes,
+ absl::string_view cluster_name) {
+ for (Node* n : nodes) {
+ n->AddAttr(kXlaClusterAttr, string(cluster_name));
+ }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+ ops::Placeholder::Attrs{});
+ Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+ Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+ Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+ "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({reshape.node()}, "cluster_0");
+ AddToCluster({shape.node()}, "cluster_1");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+ // This is needed to register the XLA_GPU device.
+ std::vector<Device*> devices;
+ TF_ASSERT_OK(DeviceFactory::AddDevices(
+ SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+ // Scope::ToGraph loses the assigned device name since it goes through
+ // GraphDef/NodeDef which does not have a field for the assigned device name.
+ Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+ n->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+ for (Device* d : devices) {
+ delete d;
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 03380e9406..f85121ca27 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index debd9038c7..94c96ac7c5 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(NodeDef* node_def);
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
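A hypothetical caller (not from this CL) showing the two overloads side by side; both simply clear the _XlaCluster attribute.

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

void StripClusterAssignment(Node* node, NodeDef* node_def) {
  RemoveFromXlaCluster(node);      // New Node* overload added in this CL.
  RemoveFromXlaCluster(node_def);  // Existing NodeDef* overload.
}

}  // namespace tensorflow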