aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc164
1 files changed, 70 insertions, 94 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index 3c1d8d5359..a46c504ac4 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -27,25 +27,21 @@ namespace {
TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
GrapplerItem item;
- GraphDef *graph = &item.graph;
- NodeDef *start_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
- NodeDef *stop_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
- NodeDef *step_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+ MutableGraphView graph(&item.graph);
+
+ NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
+ NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
+ NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
std::vector<std::pair<string, AttrValue>> range_attrs;
- NodeDef *range_node;
- TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- range_attrs, graph, &range_node));
- NodeDef *captured_input_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
- "hello", graph, &captured_input_node));
+ NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, &graph);
+ NodeDef *captured_input_node =
+ graph_utils::AddScalarConstNode<StringPiece>("hello", &graph);
NodeDef *map_node;
{
@@ -59,13 +55,11 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
AttrValue args_attr;
SetAttrValue("Targuments", &args_attr);
map_attrs[1] = std::make_pair("Targuments", args_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs,
- graph, &map_node));
+ map_node =
+ graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, &graph);
}
- NodeDef *batch_size_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph);
NodeDef *batch_node;
{
std::vector<string> batch_inputs(2);
@@ -78,16 +72,18 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
batch_attrs[1] = std::make_pair("output_types", types_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
- batch_attrs, graph, &batch_node));
+ batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ batch_attrs, &graph);
}
MapAndBatchFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
@@ -96,11 +92,11 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
NodeDef num_parallel_calls_node = output.node(
- graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output));
EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
1);
NodeDef drop_remainder_node = output.node(
- graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+ graph_utils::FindGraphNodeWithName(map_and_batch_node.input(4), output));
EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"),
map_node->attr().at("f")));
@@ -114,25 +110,20 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
GrapplerItem item;
- GraphDef *graph = &item.graph;
- NodeDef *start_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
- NodeDef *stop_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
- NodeDef *step_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+ MutableGraphView graph(&item.graph);
+ NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
+ NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
+ NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
std::vector<std::pair<string, AttrValue>> range_attrs;
- NodeDef *range_node;
- TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- range_attrs, graph, &range_node));
- NodeDef *captured_input_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
- "hello", graph, &captured_input_node));
+ NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, &graph);
+ NodeDef *captured_input_node =
+ graph_utils::AddScalarConstNode<StringPiece>("hello", &graph);
NodeDef *map_node;
{
@@ -146,16 +137,13 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
AttrValue args_attr;
SetAttrValue("Targuments", &args_attr);
map_attrs[1] = std::make_pair("Targuments", args_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs,
- graph, &map_node));
+ map_node =
+ graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, &graph);
}
- NodeDef *batch_size_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
- NodeDef *drop_remainder_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<bool>(true, graph, &drop_remainder_node));
+ NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph);
+ NodeDef *drop_remainder_node =
+ graph_utils::AddScalarConstNode<bool>(true, &graph);
NodeDef *batch_node;
{
std::vector<string> batch_inputs(3);
@@ -169,16 +157,18 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
batch_attrs[1] = std::make_pair("output_types", types_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "BatchDatasetV2", batch_inputs,
- batch_attrs, graph, &batch_node));
+ batch_node = graph_utils::AddNode("", "BatchDatasetV2", batch_inputs,
+ batch_attrs, &graph);
}
MapAndBatchFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
@@ -187,7 +177,7 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
NodeDef num_parallel_calls_node = output.node(
- graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output));
EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
1);
EXPECT_EQ(map_and_batch_node.input(4), batch_node->input(2));
@@ -203,28 +193,22 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
GrapplerItem item;
- GraphDef *graph = &item.graph;
- NodeDef *start_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
- NodeDef *stop_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
- NodeDef *step_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+ MutableGraphView graph(&item.graph);
+ NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
+ NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
+ NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
std::vector<std::pair<string, AttrValue>> range_attrs;
- NodeDef *range_node;
- TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- range_attrs, graph, &range_node));
- NodeDef *captured_input_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
- "hello", graph, &captured_input_node));
- NodeDef *num_parallel_calls_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<int>(2, graph, &num_parallel_calls_node));
+ NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, &graph);
+ NodeDef *captured_input_node =
+ graph_utils::AddScalarConstNode<StringPiece>("hello", &graph);
+ NodeDef *num_parallel_calls_node =
+ graph_utils::AddScalarConstNode<int>(2, &graph);
NodeDef *map_node;
{
@@ -239,13 +223,11 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
AttrValue args_attr;
SetAttrValue("Targuments", &args_attr);
map_attrs[1] = std::make_pair("Targuments", args_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
- map_attrs, graph, &map_node));
+ map_node = graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+ map_attrs, &graph);
}
- NodeDef *batch_size_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph);
NodeDef *batch_node;
{
std::vector<string> batch_inputs(2);
@@ -258,16 +240,18 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
batch_attrs[1] = std::make_pair("output_types", types_attr);
- TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
- batch_attrs, graph, &batch_node));
+ batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ batch_attrs, &graph);
}
MapAndBatchFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
- EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(
+ graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
NodeDef map_and_batch_node =
output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
@@ -276,11 +260,11 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
NodeDef num_parallel_calls_node2 = output.node(
- graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ graph_utils::FindGraphNodeWithName(map_and_batch_node.input(3), output));
EXPECT_EQ(num_parallel_calls_node2.attr().at("value").tensor().int64_val(0),
2);
NodeDef drop_remainder_node = output.node(
- graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+ graph_utils::FindGraphNodeWithName(map_and_batch_node.input(4), output));
EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"),
map_node->attr().at("f")));
@@ -294,27 +278,21 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
TEST(MapAndBatchFusionTest, NoChange) {
GrapplerItem item;
- GraphDef *graph = &item.graph;
+ MutableGraphView graph(&item.graph);
- NodeDef *start_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
- NodeDef *stop_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
- NodeDef *step_node;
- TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+ NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
+ NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
+ NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
std::vector<string> range_inputs(3);
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
std::vector<std::pair<string, AttrValue>> range_attrs;
- NodeDef *range_node;
- TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- range_attrs, graph, &range_node));
+ NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, &graph);
- NodeDef *batch_size_node;
- TF_ASSERT_OK(
- graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ NodeDef *batch_size_node = graph_utils::AddScalarConstNode<int64>(5, &graph);
std::vector<string> batch_inputs(2);
batch_inputs[0] = range_node->name();
batch_inputs[1] = batch_size_node->name();
@@ -325,15 +303,13 @@ TEST(MapAndBatchFusionTest, NoChange) {
AttrValue types_attr;
SetAttrValue("output_types", &types_attr);
batch_attrs[1] = std::make_pair("output_types", types_attr);
- NodeDef *batch_node;
- TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
- batch_attrs, graph, &batch_node));
+ graph_utils::AddNode("", "BatchDataset", batch_inputs, batch_attrs, &graph);
MapAndBatchFusion optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_TRUE(graph_utils::Compare(*graph, output));
+ EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
}
} // namespace