diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc | 164 |
1 files changed, 70 insertions, 94 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index 3c1d8d5359..a46c504ac4 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -27,25 +27,21 @@ namespace { TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { GrapplerItem item; - GraphDef *graph = &item.graph; - NodeDef *start_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node)); - NodeDef *stop_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node)); - NodeDef *step_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node)); + MutableGraphView graph(&item.graph); + + NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph); std::vector<string> range_inputs(3); range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); std::vector<std::pair<string, AttrValue>> range_attrs; - NodeDef *range_node; - TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - range_attrs, graph, &range_node)); - NodeDef *captured_input_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>( - "hello", graph, &captured_input_node)); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + range_attrs, &graph); + NodeDef *captured_input_node = + graph_utils::AddScalarConstNode<StringPiece>("hello", &graph); NodeDef *map_node; { @@ -59,13 +55,11 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { AttrValue args_attr; SetAttrValue("Targuments", &args_attr); map_attrs[1] = std::make_pair("Targuments", args_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, - graph, &map_node)); + map_node = + graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, &graph); } - NodeDef *batch_size_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node)); + NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph); NodeDef *batch_node; { std::vector<string> batch_inputs(2); @@ -78,16 +72,18 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { AttrValue types_attr; SetAttrValue("output_types", &types_attr); batch_attrs[1] = std::make_pair("output_types", types_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, - batch_attrs, graph, &batch_node)); + batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs, + batch_attrs, &graph); } MapAndBatchFusion optimizer; GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(map_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(batch_node->name(), output)); EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output)); NodeDef map_and_batch_node = output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output)); @@ -96,11 +92,11 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1)); EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1)); NodeDef num_parallel_calls_node = output.node( - graph_utils::FindNodeWithName(map_and_batch_node.input(3), output)); + graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output)); EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0), 1); NodeDef drop_remainder_node = output.node( - graph_utils::FindNodeWithName(map_and_batch_node.input(4), output)); + graph_utils::FindGraphNodeWithName(map_and_batch_node.input(4), output)); EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false); EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"), map_node->attr().at("f"))); @@ -114,25 +110,20 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { GrapplerItem item; - GraphDef *graph = &item.graph; - NodeDef *start_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node)); - NodeDef *stop_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node)); - NodeDef *step_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node)); + MutableGraphView graph(&item.graph); + NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph); std::vector<string> range_inputs(3); range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); std::vector<std::pair<string, AttrValue>> range_attrs; - NodeDef *range_node; - TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - range_attrs, graph, &range_node)); - NodeDef *captured_input_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>( - "hello", graph, &captured_input_node)); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + range_attrs, &graph); + NodeDef *captured_input_node = + graph_utils::AddScalarConstNode<StringPiece>("hello", &graph); NodeDef *map_node; { @@ -146,16 +137,13 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { AttrValue args_attr; SetAttrValue("Targuments", &args_attr); map_attrs[1] = std::make_pair("Targuments", args_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, - graph, &map_node)); + map_node = + graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, &graph); } - NodeDef *batch_size_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node)); - NodeDef *drop_remainder_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<bool>(true, graph, &drop_remainder_node)); + NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph); + NodeDef *drop_remainder_node = + graph_utils::AddScalarConstNode<bool>(true, &graph); NodeDef *batch_node; { std::vector<string> batch_inputs(3); @@ -169,16 +157,18 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { AttrValue types_attr; SetAttrValue("output_types", &types_attr); batch_attrs[1] = std::make_pair("output_types", types_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDatasetV2", batch_inputs, - batch_attrs, graph, &batch_node)); + batch_node = graph_utils::AddNode("", "BatchDatasetV2", batch_inputs, + batch_attrs, &graph); } MapAndBatchFusion optimizer; GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(map_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(batch_node->name(), output)); EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output)); NodeDef map_and_batch_node = output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output)); @@ -187,7 +177,7 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1)); EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1)); NodeDef num_parallel_calls_node = output.node( - graph_utils::FindNodeWithName(map_and_batch_node.input(3), output)); + graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output)); EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0), 1); EXPECT_EQ(map_and_batch_node.input(4), batch_node->input(2)); @@ -203,28 +193,22 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) { TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { GrapplerItem item; - GraphDef *graph = &item.graph; - NodeDef *start_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node)); - NodeDef *stop_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node)); - NodeDef *step_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node)); + MutableGraphView graph(&item.graph); + NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph); std::vector<string> range_inputs(3); range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); std::vector<std::pair<string, AttrValue>> range_attrs; - NodeDef *range_node; - TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - range_attrs, graph, &range_node)); - NodeDef *captured_input_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>( - "hello", graph, &captured_input_node)); - NodeDef *num_parallel_calls_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<int>(2, graph, &num_parallel_calls_node)); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + range_attrs, &graph); + NodeDef *captured_input_node = + graph_utils::AddScalarConstNode<StringPiece>("hello", &graph); + NodeDef *num_parallel_calls_node = + graph_utils::AddScalarConstNode<int>(2, &graph); NodeDef *map_node; { @@ -239,13 +223,11 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { AttrValue args_attr; SetAttrValue("Targuments", &args_attr); map_attrs[1] = std::make_pair("Targuments", args_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs, - map_attrs, graph, &map_node)); + map_node = graph_utils::AddNode("", "ParallelMapDataset", map_inputs, + map_attrs, &graph); } - NodeDef *batch_size_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node)); + NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph); NodeDef *batch_node; { std::vector<string> batch_inputs(2); @@ -258,16 +240,18 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { AttrValue types_attr; SetAttrValue("output_types", &types_attr); batch_attrs[1] = std::make_pair("output_types", types_attr); - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, - batch_attrs, graph, &batch_node)); + batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs, + batch_attrs, &graph); } MapAndBatchFusion optimizer; GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output)); - EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(map_node->name(), output)); + EXPECT_FALSE( + graph_utils::ContainsGraphNodeWithName(batch_node->name(), output)); EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output)); NodeDef map_and_batch_node = output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output)); @@ -276,11 +260,11 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1)); EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1)); NodeDef num_parallel_calls_node2 = output.node( - graph_utils::FindNodeWithName(map_and_batch_node.input(3), output)); + graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output)); EXPECT_EQ(num_parallel_calls_node2.attr().at("value").tensor().int64_val(0), 2); NodeDef drop_remainder_node = output.node( - graph_utils::FindNodeWithName(map_and_batch_node.input(4), output)); + graph_utils::FindGraphNodeWithName(map_and_batch_node.input(4), output)); EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false); EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"), map_node->attr().at("f"))); @@ -294,27 +278,21 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { TEST(MapAndBatchFusionTest, NoChange) { GrapplerItem item; - GraphDef *graph = &item.graph; + MutableGraphView graph(&item.graph); - NodeDef *start_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node)); - NodeDef *stop_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node)); - NodeDef *step_node; - TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node)); + NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph); std::vector<string> range_inputs(3); range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); std::vector<std::pair<string, AttrValue>> range_attrs; - NodeDef *range_node; - TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - range_attrs, graph, &range_node)); + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + range_attrs, &graph); - NodeDef *batch_size_node; - TF_ASSERT_OK( - graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node)); + NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph); std::vector<string> batch_inputs(2); batch_inputs[0] = range_node->name(); batch_inputs[1] = batch_size_node->name(); @@ -325,15 +303,13 @@ TEST(MapAndBatchFusionTest, NoChange) { AttrValue types_attr; SetAttrValue("output_types", &types_attr); batch_attrs[1] = std::make_pair("output_types", types_attr); - NodeDef *batch_node; - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, - batch_attrs, graph, &batch_node)); + graph_utils::AddNode("", "BatchDataset", batch_inputs, batch_attrs, &graph); MapAndBatchFusion optimizer; GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::Compare(*graph, output)); + EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); } } // namespace |