aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-06-02 14:06:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-02 14:09:39 -0700
commitd23f115d89ad6111674f53135d669cb2d2c086f0 (patch)
treec5029bc1d88ae189c2070eae0ab050feee5f6aa9 /tensorflow
parenta06e521204d7b5a2dd27de44efbab352ff918aa7 (diff)
Don't cluster Identity nodes that forward tensor refs
XLA cannot implement the forward-tensor-ref semantic -- there is no guaranteed aliasing between the input and output of the XLA cluster.
PiperOrigin-RevId: 199005227
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc26
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc47
2 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 8e2ee0f1d7..07ee93d79e 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -46,6 +46,12 @@ const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
namespace {
+// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward
+// a ref tensor input to its output.
+static bool AlwaysForwardsRefInput(const Node& node) {
+ return node.IsIdentity();
+}
+
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
@@ -60,6 +66,26 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return false;
}
}
+
+ // XLA does not offer guaranteed aliasing between the input and output of the
+ // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
+ // such nodes out of XLA clusters.
+ if (AlwaysForwardsRefInput(node)) {
+ for (const Edge* incoming_edge : node.in_edges()) {
+ if (incoming_edge->IsControlEdge()) {
+ continue;
+ }
+
+ Node* incoming_node = incoming_edge->src();
+ if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
+ VLOG(2) << "Not clustering " << node.def().ShortDebugString()
+ << " because of ref input " << incoming_node->name() << " "
+ << incoming_node->type_string();
+ return false;
+ }
+ }
+ }
+
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 703d8825d7..772c92d369 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) {
}
}
+TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output add = ops::Add(root.WithOpName("add"), neg, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name}, {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
+TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output identity = ops::Identity(root.WithOpName("identity"), neg);
+ Output add = ops::Add(root.WithOpName("add"), identity, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name},
+ {"identity", cluster_name},
+ {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
} // namespace
} // namespace tensorflow