aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 703d8825d7..772c92d369 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) {
}
}
+TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output add = ops::Add(root.WithOpName("add"), neg, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name}, {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
+TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output identity = ops::Negate(root.WithOpName("identity"), neg);
+ Output add = ops::Add(root.WithOpName("add"), identity, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name},
+ {"identity", cluster_name},
+ {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
} // namespace
} // namespace tensorflow