aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/mutable_graph_view_test.cc
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-08-02 16:55:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 16:59:30 -0700
commitbb3ed5ee461988f1020b9768a42ce27966ec08dc (patch)
tree0b2fdb35eda6c64dfc9a365587eee46b28c1d758 /tensorflow/core/grappler/mutable_graph_view_test.cc
parent4bc5c6c77daea8d5e60c22732b56495a0cd6c681 (diff)
Experimental Cl which adds `LatencyStatsDataset` op after each `Dataset` op to record latency on each edge of dataset input pipeline.
PiperOrigin-RevId: 207190025
Diffstat (limited to 'tensorflow/core/grappler/mutable_graph_view_test.cc')
-rw-r--r--tensorflow/core/grappler/mutable_graph_view_test.cc67
1 files changed, 50 insertions, 17 deletions
diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc
index f09dfb8271..2536bec35d 100644
--- a/tensorflow/core/grappler/mutable_graph_view_test.cc
+++ b/tensorflow/core/grappler/mutable_graph_view_test.cc
@@ -23,7 +23,18 @@ namespace tensorflow {
namespace grappler {
namespace {
-TEST(MutableGraphViewTest, AddAndReplaceInput) {
+bool FindChildWithName(const MutableGraphView& graph,
+ const string& output_port_name,
+ const string& input_name) {
+ GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
+ auto fanout = graph.GetFanout(output_port);
+ for (auto& input_port : fanout) {
+ if (input_port.node->name() == input_name) return true;
+ }
+ return false;
+}
+
+TrivialTestGraphInputYielder SimpleGraph() {
// This outputs simple graph like:
// x
// / \
@@ -35,7 +46,13 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
// AddN AddN_1
// \ /
// y
- TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"});
+ TrivialTestGraphInputYielder simple_graph(2, 2, 2, false,
+ {"/CPU:0", "/GPU:0"});
+ return simple_graph;
+}
+
+TEST(MutableGraphViewTest, AddAndReplaceInput) {
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -49,18 +66,7 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
- auto find_child_with_name = [&graph](string output_port_name,
- string input_name) {
- GraphView::OutputPort output_port =
- graph.GetOutputPort(output_port_name, 0);
- auto fanout = graph.GetFanout(output_port);
- for (auto& input_port : fanout) {
- if (input_port.node->name() == input_name) return true;
- }
- return false;
- };
-
- EXPECT_FALSE(find_child_with_name("Square", "new_node"));
+ EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node"));
NodeDef new_node = *input.node;
new_node.set_name("new_node");
@@ -70,13 +76,40 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
EXPECT_NE(graph.GetNode("new_node"), nullptr);
graph.ReplaceInput(*input.node, *node_in_graph);
- EXPECT_TRUE(find_child_with_name("Square", "new_node"));
- EXPECT_TRUE(find_child_with_name("new_node", "y"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node"));
+ EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
+}
+
+TEST(MutableGraphViewTest, InsertNodes) {
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
+
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphDef new_graph = item.graph;
+ MutableGraphView graph(&new_graph);
+
+ GraphView::InputPort input = graph.GetInputPort("AddN", 0);
+
+ NodeDef new_node = *input.node;
+ new_node.set_name("new_node");
+ new_node.set_input(0, input.node->name());
+
+ EXPECT_EQ(graph.GetNode("new_node"), nullptr);
+ graph.InsertNode(*input.node, std::move(new_node));
+ EXPECT_NE(graph.GetNode("new_node"), nullptr);
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1"));
+ EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node"));
+ EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y"));
+ EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
}
TEST(MutableGraphViewTest, DeleteNodes) {
// Outputs simple graph as described in first test.
- TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"});
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));