diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-08-02 16:55:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 16:59:30 -0700 |
commit | bb3ed5ee461988f1020b9768a42ce27966ec08dc (patch) | |
tree | 0b2fdb35eda6c64dfc9a365587eee46b28c1d758 /tensorflow/core/grappler/mutable_graph_view_test.cc | |
parent | 4bc5c6c77daea8d5e60c22732b56495a0cd6c681 (diff) |
Experimental Cl which adds `LatencyStatsDataset` op after each `Dataset` op to record latency on each edge of dataset input pipeline.
PiperOrigin-RevId: 207190025
Diffstat (limited to 'tensorflow/core/grappler/mutable_graph_view_test.cc')
-rw-r--r-- | tensorflow/core/grappler/mutable_graph_view_test.cc | 67 |
1 files changed, 50 insertions, 17 deletions
diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index f09dfb8271..2536bec35d 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -23,7 +23,18 @@ namespace tensorflow { namespace grappler { namespace { -TEST(MutableGraphViewTest, AddAndReplaceInput) { +bool FindChildWithName(const MutableGraphView& graph, + const string& output_port_name, + const string& input_name) { + GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0); + auto fanout = graph.GetFanout(output_port); + for (auto& input_port : fanout) { + if (input_port.node->name() == input_name) return true; + } + return false; +} + +TrivialTestGraphInputYielder SimpleGraph() { // This outputs simple graph like: // x // / \ @@ -35,7 +46,13 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { // AddN AddN_1 // \ / // y - TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"}); + TrivialTestGraphInputYielder simple_graph(2, 2, 2, false, + {"/CPU:0", "/GPU:0"}); + return simple_graph; +} + +TEST(MutableGraphViewTest, AddAndReplaceInput) { + TrivialTestGraphInputYielder fake_input = SimpleGraph(); GrapplerItem item; CHECK(fake_input.NextItem(&item)); @@ -49,18 +66,7 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { EXPECT_EQ("Square", fanin.node->name()); EXPECT_EQ(0, fanin.port_id); - auto find_child_with_name = [&graph](string output_port_name, - string input_name) { - GraphView::OutputPort output_port = - graph.GetOutputPort(output_port_name, 0); - auto fanout = graph.GetFanout(output_port); - for (auto& input_port : fanout) { - if (input_port.node->name() == input_name) return true; - } - return false; - }; - - EXPECT_FALSE(find_child_with_name("Square", "new_node")); + EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node")); NodeDef new_node = *input.node; new_node.set_name("new_node"); @@ -70,13 +76,40 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { EXPECT_NE(graph.GetNode("new_node"), nullptr); graph.ReplaceInput(*input.node, *node_in_graph); - EXPECT_TRUE(find_child_with_name("Square", "new_node")); - EXPECT_TRUE(find_child_with_name("new_node", "y")); + EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node")); + EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); +} + +TEST(MutableGraphViewTest, InsertNodes) { + TrivialTestGraphInputYielder fake_input = SimpleGraph(); + + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphDef new_graph = item.graph; + MutableGraphView graph(&new_graph); + + GraphView::InputPort input = graph.GetInputPort("AddN", 0); + + NodeDef new_node = *input.node; + new_node.set_name("new_node"); + new_node.set_input(0, input.node->name()); + + EXPECT_EQ(graph.GetNode("new_node"), nullptr); + graph.InsertNode(*input.node, std::move(new_node)); + EXPECT_NE(graph.GetNode("new_node"), nullptr); + EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN")); + EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1")); + EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN")); + EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1")); + EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node")); + EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y")); + EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); } TEST(MutableGraphViewTest, DeleteNodes) { // Outputs simple graph as described in first test. - TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"}); + TrivialTestGraphInputYielder fake_input = SimpleGraph(); GrapplerItem item; CHECK(fake_input.NextItem(&item)); |