aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 772c92d369..2c5f4fb774 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -680,5 +681,37 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
EXPECT_EQ(clusters, expected_clusters);
}
+TEST(XlaCompilationTest, ClusterControlTrigger) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
+ "sender", 0, "receiver");
+ Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
+ "sender", 0, "receiver");
+ Output const_a = ops::Const(root.WithOpName("const_a"), 42);
+
+ ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
+ ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
+ root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
+ root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
+ root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ // ctrl_trigger_a has inputs with mismatching deadness so it won't be
+ // clustered. ctrl_trigger_b is okay to cluster.
+ std::unordered_map<string, string> expected_clusters(
+ {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
} // namespace
} // namespace tensorflow