author    Sanjoy Das <sanjoy@google.com>  2018-09-18 15:42:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-18 15:46:41 -0700
commit    e1a32c98210f8ebba42a0397259d948e1433c09e (patch)
tree      0a62289fe29cf2c0f481bdca90c03811a47caadf /tensorflow/compiler/jit
parent    6c8f6920e8bad10429ac0b88abbe0ace5a5e9a72 (diff)
"Isolate" must-be-constant side effecting operations
I first tried to fix this issue in cr/209996730, but that didn't quite fix the problem for XLA_* devices. A node assigned to an XLA_* device must be compiled, so the cr/209996730 fix of simply not compiling the nodes doesn't generalize to XLA_* devices. Instead we now "isolate" these nodes, putting each of them in at most a trivial one-node cluster. For non-XLA devices even this trivial cluster is ignored because of flags->tf_xla_min_cluster_size.

I initially considered a more principled data-flow-analysis based solution, but decided the upfront work isn't worth it until I see a clear motivating example.

PiperOrigin-RevId: 213531437
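As a concrete illustration of the pattern this change targets, below is a minimal sketch that builds the same must-be-constant, side-effecting graph as the RandomShapeOnXlaDevice test added in this commit. The helper name BuildRandomShapeGraph is made up for the sketch, and the includes are the usual C++ ops headers; only the op pattern itself comes from the test.

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/graph/graph.h"

    // RandomUniformInt is stateful, so XLA cannot constant fold it, yet Reshape
    // needs a constant shape operand to be statically shaped.  With this change
    // the `shape` node is "isolated" instead of being merged into `reshape`'s
    // cluster.
    tensorflow::Status BuildRandomShapeGraph(tensorflow::Graph* graph) {
      using namespace tensorflow;  // Brevity for the sketch only.
      Scope root = Scope::NewRootScope().ExitOnError();
      Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
      Output shape = ops::RandomUniformInt(
          root.WithOpName("shape_rng"), shape_shape,
          ops::Const(root.WithOpName("minval"), 1),
          ops::Const(root.WithOpName("maxval"), 20));
      Output input = ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT);
      ops::Reshape(root.WithOpName("reshape"), input, shape);
      // After this pass, `shape_rng` and `reshape` must not share a cluster.
      return root.ToGraph(graph);
    }

The two EXPECT_NE checks in the new test below verify exactly this outcome: both nodes are clustered on the XLA_GPU device, but never into the same cluster.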
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc             | 73
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc        | 66
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc | 21
3 files changed, 147 insertions, 13 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e52ae..1eaedbfbfb 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) {
return elementwise_ops->count(node.op()) > 0;
}
+// Nodes that XLA can compile are put in `candidates`. Nodes put in
+// `isolated_nodes` must either be unclustered or be put in trivial single-node
+// clusters.
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates) {
+ OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -411,6 +414,8 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
+ VLOG(4) << "Device type for " << node->name() << ": "
+ << device_type.type_string();
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
// is_compilable_fn has already logged the reason if it returned false.
@@ -439,19 +444,56 @@ Status FindCompilationCandidates(
<< node->type_string();
continue;
}
- if (compile_time_const_nodes[node->id()] &&
- !registration->requires_compilation) {
+ if (compile_time_const_nodes[node->id()]) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
- // We need to be able to constant fold the nodes in
- // compile_time_const_nodes given constant inputs (required by XLA) and
- // therefore can't auto-cluster stateful ops since these can never be
- // constant folded.
- VLOG(2) << "Rejecting " << node->name()
- << ": must-be-constant stateful op";
- continue;
+ // It is easiest to demonstrate the problem we're trying to solve with
+ // an example. Say we have this graph:
+ //
+ // shape = RandomUniformInt();
+ // reshape = Reshape(input, shape)
+ //
+ // Both RandomUniformInt and Reshape are compilable by XLA so, absent
+ // any other reason, we will try to put both shape and reshape in the
+ // same cluster. However, since XLA only supports statically shaped
+ // values, it will expect to be able to constant fold `shape` to get a
+ // static shape for `reshape`. This is a problem because side-effecting
+ // ops like RandomUniformInt() cannot be constant folded. We fix this
+ // by putting `shape` and `reshape` in different clusters, which results
+ // in us recompiling `reshape`'s cluster for every new value of `shape`,
+ // making `reshape` statically sized within each compilation. We
+ // simplify the solution even further by disallowing operations like
+ // `shape` from being part of *any* non-trivial cluster. They're either
+ // not compiled by XLA at all or, if assigned to an XLA_* device
+ // with "must compile" semantics, compiled into a trivial single-op
+ // cluster. This approach leaves some room for improvement, and we can
+ // consider implementing a more aggressive data-flow-analysis based
+ // solution in the future if needed.
+ //
+ // One ugly problem we have to contend with: certain sets of ops *have*
+ // to be in the same cluster because values flowing between them have
+ // types that can't be live-in or live-out of a cluster. These ops are:
+ //
+ // - TensorArray ops operating on the same TensorArray instance.
+ // - Stack ops operating on the same Stack instance.
+ //
+ // To work around this we avoid isolating these specific ops. Because
+ // of this concession it is unsound to auto-cluster them: we could end
+ // up with clusters we cannot compile (we can't constant fold, say, a
+ // TensorArrayRead or a StackPopV2). But we don't auto-cluster these
+ // operations today, so we're fine for now.
+ const XlaResourceOpInfo* op_info =
+ GetResourceOpInfoForOp(node->type_string());
+ bool is_tensor_array_or_stack_op =
+ op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+ if (!is_tensor_array_or_stack_op) {
+ VLOG(2) << "Isolating " << node->name()
+ << ": must-be-constant stateful op";
+ isolated_nodes->insert(node);
+ // Keep going and execute all the other checks.
+ }
}
}
// We don't auto-cluster functional control flow nodes containing resource
@@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
+ gtl::FlatSet<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
: Env::Default(),
- is_compilable_fn, &compilation_candidates));
+ is_compilable_fn, &compilation_candidates, &isolated_nodes));
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
@@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
+
+ if (isolated_nodes.count(node_from)) {
+ continue;
+ }
+
string from_scope;
string to_scope;
for (int to : cycles.Successors(from)) {
@@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl(
node_to->assigned_device_name()) {
continue;
}
+ if (isolated_nodes.count(node_to)) {
+ continue;
+ }
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
// edge. This restriction is overridden if the global_jit_level is ON. If
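To make the effect of the two `isolated_nodes.count(...)` guards above concrete, here is a toy, standalone sketch (not TensorFlow code; all names are made up): edges that touch an isolated node are never contracted, so an isolated node can only ever end up in a trivial single-node cluster.

    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <utility>
    #include <vector>

    int main() {
      // Nodes flagged for isolation by the candidate-finding step.
      std::unordered_set<std::string> isolated = {"shape_rng"};

      // Candidate edges the clustering loop would otherwise try to contract.
      std::vector<std::pair<std::string, std::string>> edges = {
          {"shape_rng", "reshape"}, {"reshape", "add"}};

      for (const auto& e : edges) {
        // Mirrors the `isolated_nodes.count(node_from/node_to)` checks above.
        if (isolated.count(e.first) || isolated.count(e.second)) {
          std::cout << "skip:  " << e.first << " -> " << e.second << "\n";
          continue;
        }
        std::cout << "merge: " << e.first << " -> " << e.second << "\n";
      }
      return 0;
    }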
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a4c8..4f9145b479 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -894,5 +894,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) {
EXPECT_EQ(clusters["fn_call"], "");
}
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape =
+ ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+ ops::Const(root.WithOpName("test/minval"), 1),
+ ops::Const(root.WithOpName("test/maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/shape_rng"], "");
+ EXPECT_NE(clusters["test/reshape"], "");
+ EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+ DT_INT32);
+ Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+ ops::TensorArrayWrite tensor_array_write(
+ root.WithOpName("test/write"), tensor_array.handle, zero,
+ ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+ Output tensor_array_read =
+ ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+ zero, tensor_array_write.flow_out, DT_INT32);
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"),
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+ tensor_array_read);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/read"], "");
+ EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index 65669877f7..d56d0f8ccf 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,18 +14,35 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
SessionOptions* session_options) {
- // Assign all nodes to the CPU device.
+ // Assign all unassigned nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
+ // Call AddDevices to register the XLA devices.
+ //
+ // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
+ // make this more direct, but probably not worth it solely for this test.
+ std::vector<Device*> devices;
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
+
+ auto delete_devices = gtl::MakeCleanup([&] {
+ for (Device* d : devices) {
+ delete d;
+ }
+ });
+
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;