 tensorflow/compiler/jit/mark_for_compilation_pass.cc      | 26 +
 tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 47 +
 2 files changed, 73 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 8e2ee0f1d7..07ee93d79e 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -46,6 +46,12 @@ const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
 
 namespace {
 
+// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward
+// a ref tensor input to its output.
+static bool AlwaysForwardsRefInput(const Node& node) {
+  return node.IsIdentity();
+}
+
 bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
   // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
   // is really a kind of function call and will be handled by
@@ -60,6 +66,26 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
       return false;
     }
   }
+
+  // XLA does not offer guaranteed aliasing between the input and output of the
+  // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
+  // such nodes out of XLA clusters.
+  if (AlwaysForwardsRefInput(node)) {
+    for (const Edge* incoming_edge : node.in_edges()) {
+      if (incoming_edge->IsControlEdge()) {
+        continue;
+      }
+
+      Node* incoming_node = incoming_edge->src();
+      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
+        VLOG(2) << "Not clustering " << node.def().ShortDebugString()
+                << " because of ref input " << incoming_node->name() << " "
+                << incoming_node->type_string();
+        return false;
+      }
+    }
+  }
+
   return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
 }
 
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 703d8825d7..772c92d369 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) {
   }
 }
 
+TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output variable = ops::Variable(root.WithOpName("variable"),
+                                  PartialTensorShape{}, DT_FLOAT);
+  Output read = ops::Identity(root.WithOpName("read"), variable);
+  Output neg = ops::Negate(root.WithOpName("negate"), read);
+  Output add = ops::Add(root.WithOpName("add"), neg, neg);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+  TF_ASSERT_OK(MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+  ASSERT_FALSE(clusters.empty());
+  string cluster_name = clusters.begin()->second;
+
+  std::unordered_map<string, string> expected_clusters(
+      {{"negate", cluster_name}, {"add", cluster_name}});
+  EXPECT_EQ(clusters, expected_clusters);
+}
+
+TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output variable = ops::Variable(root.WithOpName("variable"),
+                                  PartialTensorShape{}, DT_FLOAT);
+  Output read = ops::Identity(root.WithOpName("read"), variable);
+  Output neg = ops::Negate(root.WithOpName("negate"), read);
+  Output identity = ops::Identity(root.WithOpName("identity"), neg);
+  Output add = ops::Add(root.WithOpName("add"), identity, neg);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+  TF_ASSERT_OK(MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+  ASSERT_FALSE(clusters.empty());
+  string cluster_name = clusters.begin()->second;
+
+  std::unordered_map<string, string> expected_clusters(
+      {{"negate", cluster_name},
+       {"identity", cluster_name},
+       {"add", cluster_name}});
+  EXPECT_EQ(clusters, expected_clusters);
+}
+
 }  // namespace
 }  // namespace tensorflow
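
For context on the check the first hunk adds: in TensorFlow, ref DataTypes sit at value type + kDataTypeRefOffset (100 in tensorflow/core/framework/types.h), and IsRefType simply tests for that offset. The standalone program below is a minimal sketch of the guard, not the patch itself: the Node and Edge structs here are simplified stand-ins for tensorflow::Node and tensorflow::Edge, and the two-value DataType enum models only the convention just described.

// Sketch of the new clustering guard, under the simplifications noted above.
#include <iostream>
#include <string>
#include <vector>

enum DataType { DT_FLOAT = 1, DT_FLOAT_REF = 101 };
constexpr int kDataTypeRefOffset = 100;  // matches tensorflow/core/framework/types.h

bool IsRefType(DataType dtype) { return dtype > kDataTypeRefOffset; }

struct Node;

struct Edge {
  const Node* src;
  int src_output;   // which output slot of `src` feeds this edge
  bool is_control;  // control edges carry no tensor, hence nothing to forward
};

struct Node {
  std::string name;
  std::string type_string;             // op name, e.g. "Identity"
  std::vector<DataType> output_types;  // dtype of each output slot
  std::vector<Edge> in_edges;
  bool IsIdentity() const { return type_string == "Identity"; }
};

// Mirrors AlwaysForwardsRefInput from the patch: only Identity is treated as
// a guaranteed ref-forwarder.
bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

// Mirrors the new check in HasXLAKernel: a ref-forwarding node with a
// ref-typed (non-control) input must stay out of the XLA cluster, because
// XLA cannot guarantee input/output aliasing.
bool SafeToCluster(const Node& node) {
  if (!AlwaysForwardsRefInput(node)) return true;
  for (const Edge& e : node.in_edges) {
    if (e.is_control) continue;
    if (IsRefType(e.src->output_types[e.src_output])) return false;
  }
  return true;
}

int main() {
  // Same shape as the tests: Variable -> read (Identity) -> negate -> identity.
  Node variable{"variable", "Variable", {DT_FLOAT_REF}, {}};
  Node read{"read", "Identity", {DT_FLOAT}, {{&variable, 0, false}}};
  Node negate{"negate", "Neg", {DT_FLOAT}, {{&read, 0, false}}};
  Node identity{"identity", "Identity", {DT_FLOAT}, {{&negate, 0, false}}};

  std::cout << "read:     " << SafeToCluster(read) << "\n";      // 0: ref input
  std::cout << "identity: " << SafeToCluster(identity) << "\n";  // 1: non-ref
}

Run as written, the sketch reproduces what the two tests assert: "read" (an Identity fed by the Variable's ref output) is rejected from clustering, while "identity" (fed by the non-ref Neg output) remains clusterable.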