diff options
Diffstat (limited to 'tensorflow/core/graph/algorithm_test.cc')
-rw-r--r-- | tensorflow/core/graph/algorithm_test.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index f67d5a2fd2..60a3e66aa1 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -36,6 +36,11 @@ namespace { REGISTER_OP("TestParams").Output("o: float"); REGISTER_OP("TestInput").Output("a: float").Output("b: float"); REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestUnary").Input("a: float").Output("o: float"); +REGISTER_OP("TestBinary") + .Input("a: float") + .Input("b: float") + .Output("o: float"); // Compares that the order of nodes in 'inputs' respects the // pair orders described in 'ordered_pairs'. @@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) { EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error)); } } + +TEST(AlgorithmTest, PostOrderWithEdgeFilter) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + string error; + Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0")); + Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1")); + Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2")); + Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1); + + std::vector<Node*> post_order; + auto edge_filter = [&](const Edge& e) { + return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id()); + }; + + std::vector<Node*> expected_post_order = { + g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()), + g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()}; + + std::vector<Node*> expected_reverse_post_order = expected_post_order; + std::reverse(expected_reverse_post_order.begin(), + expected_reverse_post_order.end()); + + GetPostOrder(g, &post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_post_order.size(), post_order.size()); + for (int i = 0; i < post_order.size(); i++) { + CHECK_EQ(post_order[i], expected_post_order[i]) + << post_order[i]->name() << " vs. " << expected_post_order[i]->name(); + } + + std::vector<Node*> reverse_post_order; + GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size()); + for (int i = 0; i < reverse_post_order.size(); i++) { + CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i]) + << reverse_post_order[i]->name() << " vs. " + << expected_reverse_post_order[i]->name(); + } +} } // namespace } // namespace tensorflow |