diff options
Diffstat (limited to 'tensorflow/core/graph/graph_test.cc')
-rw-r--r-- | tensorflow/core/graph/graph_test.cc | 64 |
1 files changed, 62 insertions, 2 deletions
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index e5d57facaa..d1c89a48bd 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -118,6 +118,25 @@ class GraphTest : public ::testing::Test { LOG(FATAL) << name; } + bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, + const Node* dst) { + for (const Edge *e : dst->in_edges()) { + if (e->IsControlEdge() && + e->src() == src && + e->src_output() == Graph::kControlSlot && + e->dst_input() == Graph::kControlSlot) { + return true; + } + } + std::string control_edge_name = strings::StrCat("^", src->name()); + for (int i = 0; i < dst->def().input_size(); ++i) { + if (dst->def().input(i) == control_edge_name) { + return true; + } + } + return false; + } + Graph graph_; private: @@ -458,8 +477,8 @@ TEST_F(GraphTest, AddControlEdge) { EXPECT_TRUE(edge == nullptr); EXPECT_EQ(b->def().input_size(), 2); - // Can add redundant control edge with create_duplicate. - edge = graph_.AddControlEdge(a, b, /*create_duplicate=*/true); + // Can add redundant control edge with allow_duplicates. + edge = graph_.AddControlEdge(a, b, /*allow_duplicates=*/true); EXPECT_TRUE(edge != nullptr); // create_duplicate causes the NodeDef not to be updated. ASSERT_EQ(b->def().input_size(), 2); @@ -477,6 +496,47 @@ TEST_F(GraphTest, AddControlEdge) { EXPECT_EQ(b->def().input_size(), 2); } +TEST_F(GraphTest, RemoveControlEdge) { + FromGraphDef( + "node { name: 'A' op: 'OneOutput' }" + "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }" + "node { name: 'C' op: 'NoOp' } "); + Node* a = FindNode("A"); + Node* b = FindNode("B"); + Node* c = FindNode("C"); + + // Add a control edge. + const Edge* edge_1 = graph_.AddControlEdge(c, a); + const Edge* edge_2 = graph_.AddControlEdge(a, b); + ASSERT_TRUE(edge_1 != nullptr); + ASSERT_TRUE(edge_2 != nullptr); + + ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(c, a)); + ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b)); + + graph_.RemoveControlEdge(edge_1); + ASSERT_TRUE(!ControlEdgeExistsInGraphOrNodeDef(c, a)); + ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b)); + + graph_.RemoveControlEdge(edge_2); + ASSERT_TRUE(!ControlEdgeExistsInGraphOrNodeDef(c, a)); + ASSERT_TRUE(!ControlEdgeExistsInGraphOrNodeDef(a, b)); + + // Test removing a duplicate control edge. + // Note that unless allow_duplicates is true, the duplicate edge + // will not be added. That's why we expect edge_4 to be a null + // pointer. We are not testing with allow_duplicates set to true, + // as that is a highly unlikely use case that does not make much + // sense. + const Edge* edge_3 = graph_.AddControlEdge(c, a); + const Edge* edge_4 = graph_.AddControlEdge(c, a); + ASSERT_TRUE(edge_3 != nullptr); + ASSERT_TRUE(edge_4 == nullptr); + + graph_.RemoveControlEdge(edge_3); + ASSERT_TRUE(!ControlEdgeExistsInGraphOrNodeDef(c, a)); +} + TEST_F(GraphTest, UpdateEdge) { // Build a little graph Node* a = FromNodeDef("A", "OneOutput", 0); |