aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/algorithm_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/algorithm_test.cc')
-rw-r--r--tensorflow/core/graph/algorithm_test.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
index f67d5a2fd2..60a3e66aa1 100644
--- a/tensorflow/core/graph/algorithm_test.cc
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -36,6 +36,11 @@ namespace {
REGISTER_OP("TestParams").Output("o: float");
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
+REGISTER_OP("TestBinary")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float");
// Compares that the order of nodes in 'inputs' respects the
// pair orders described in 'ordered_pairs'.
@@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error));
}
}
+
+TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ string error;
+ Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
+ Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
+ Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
+ Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));
+
+ Graph g(OpRegistry::Global());
+ TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
+
+ g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);
+
+ std::vector<Node*> post_order;
+ auto edge_filter = [&](const Edge& e) {
+ return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
+ };
+
+ std::vector<Node*> expected_post_order = {
+ g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
+ g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
+
+ std::vector<Node*> expected_reverse_post_order = expected_post_order;
+ std::reverse(expected_reverse_post_order.begin(),
+ expected_reverse_post_order.end());
+
+ GetPostOrder(g, &post_order, /*stable_comparator=*/{},
+ /*edge_filter=*/edge_filter);
+
+ ASSERT_EQ(expected_post_order.size(), post_order.size());
+ for (int i = 0; i < post_order.size(); i++) {
+ CHECK_EQ(post_order[i], expected_post_order[i])
+ << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
+ }
+
+ std::vector<Node*> reverse_post_order;
+ GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
+ /*edge_filter=*/edge_filter);
+
+ ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
+ for (int i = 0; i < reverse_post_order.size(); i++) {
+ CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i])
+ << reverse_post_order[i]->name() << " vs. "
+ << expected_reverse_post_order[i]->name();
+ }
+}
} // namespace
} // namespace tensorflow