diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_utils_test.cc | 229 |
1 files changed, 145 insertions, 84 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 00f66c9bc1..59ed79ab8f 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -23,134 +24,194 @@ namespace grappler { namespace graph_utils { namespace { -class GraphUtilsTest : public ::testing::Test {}; - -TEST_F(GraphUtilsTest, AddScalarConstNodeBool) { - GraphDef graph; - NodeDef* bool_node; - TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node)); - EXPECT_TRUE(ContainsNodeWithName(bool_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeBool) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph); + EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph())); EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); } -TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) { - GraphDef graph; - NodeDef* double_node; - TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node)); - EXPECT_TRUE(ContainsNodeWithName(double_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeDouble) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph); + EXPECT_TRUE( + ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph())); EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); } -TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) { - GraphDef graph; - NodeDef* float_node; - TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node)); - EXPECT_TRUE(ContainsNodeWithName(float_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeFloat) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph); + EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph())); EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); } -TEST_F(GraphUtilsTest, AddScalarConstNodeInt) { - GraphDef graph; - NodeDef* int_node; - TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node)); - EXPECT_TRUE(ContainsNodeWithName(int_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeInt) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* int_node = AddScalarConstNode<int>(42, &graph); + EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph())); EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); } -TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) { - GraphDef graph; - NodeDef* int64_node; - TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node)); - EXPECT_TRUE(ContainsNodeWithName(int64_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeInt64) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph); + EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph())); EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); } -TEST_F(GraphUtilsTest, AddScalarConstNodeString) { - GraphDef graph; - NodeDef* string_node; - TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node)); - EXPECT_TRUE(ContainsNodeWithName(string_node->name(), graph)); +TEST(GraphUtilsTest, AddScalarConstNodeString) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph); + EXPECT_TRUE( + ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph())); EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } -TEST_F(GraphUtilsTest, Compare) { - GraphDef graphA; - GraphDef graphB; - EXPECT_TRUE(Compare(graphA, graphB)); +TEST(GraphUtilsTest, Compare) { + GraphDef graph_def_a; + MutableGraphView graph_a(&graph_def_a); + GraphDef graph_def_b; + MutableGraphView graph_b(&graph_def_b); + + EXPECT_TRUE(Compare(graph_def_a, graph_def_b)); - NodeDef* nodeA; - TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graphA, &nodeA)); - NodeDef* nodeB; - TF_EXPECT_OK(AddNode("B", "OpB", {"A"}, {}, &graphA, &nodeB)); - EXPECT_FALSE(Compare(graphA, graphB)); + AddNode("A", "OpA", {}, {}, &graph_a); + AddNode("B", "OpB", {"A"}, {}, &graph_a); + EXPECT_FALSE(Compare(graph_def_a, graph_def_b)); - graphB.mutable_node()->CopyFrom(graphA.node()); - EXPECT_TRUE(Compare(graphA, graphB)); + graph_def_b.mutable_node()->CopyFrom(graph_def_a.node()); + EXPECT_TRUE(Compare(graph_def_a, graph_def_b)); } -TEST_F(GraphUtilsTest, ContainsNodeWithName) { - GraphDef graph; - EXPECT_TRUE(!ContainsNodeWithName("A", graph)); +TEST(GraphUtilsTest, ContainsGraphNodeWithName) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); - NodeDef* node; - TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); - EXPECT_TRUE(ContainsNodeWithName("A", graph)); + AddNode("A", "OpA", {}, {}, &graph); + EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph())); - TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); - EXPECT_TRUE(!ContainsNodeWithName("A", graph)); + graph.DeleteNodes({"A"}); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); } -TEST_F(GraphUtilsTest, ContainsNodeWithOp) { - GraphDef graph; - EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); +TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { + FunctionDefLibrary library; + EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library)); + FunctionDef* new_function = library.add_function(); + SetUniqueGraphFunctionName("new_function", &library, new_function); - NodeDef* node; - TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); - EXPECT_TRUE(ContainsNodeWithOp("OpA", graph)); + EXPECT_TRUE( + ContainsGraphFunctionWithName(new_function->signature().name(), library)); +} - TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); - EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph)); +TEST(GraphUtilsTest, ContainsFunctionNodeWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_FALSE(ContainsFunctionNodeWithName( + "weird_name_that_should_not_be_there", function)); + EXPECT_TRUE(ContainsFunctionNodeWithName("two", function)); } -TEST_F(GraphUtilsTest, FindNodeWithName) { - GraphDef graph; - EXPECT_EQ(FindNodeWithName("A", graph), -1); +TEST(GraphUtilsTest, ContainsNodeWithOp) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); - NodeDef* node; - TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); - EXPECT_NE(FindNodeWithName("A", graph), -1); + AddNode("A", "OpA", {}, {}, &graph); + EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph())); - TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); - EXPECT_EQ(FindNodeWithName("A", graph), -1); + graph.DeleteNodes({"A"}); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); } -TEST_F(GraphUtilsTest, FindNodeWithOp) { - GraphDef graph; - EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); +TEST(GraphUtilsTest, FindGraphNodeWithName) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); - NodeDef* node; - TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node)); - EXPECT_NE(FindNodeWithOp("OpA", graph), -1); + AddNode("A", "OpA", {}, {}, &graph); + EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + + graph.DeleteNodes({"A"}); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); +} - TF_EXPECT_OK(DeleteNodes({"A"}, &graph)); - EXPECT_EQ(FindNodeWithOp("OpA", graph), -1); +TEST(GraphUtilsTest, FindFunctionWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ( + FindFunctionNodeWithName("weird_name_that_should_not_be_there", function), + -1); + EXPECT_NE(FindFunctionNodeWithName("two", function), -1); } -TEST_F(GraphUtilsTest, SetUniqueName) { - GraphDef graph; +TEST(GraphUtilsTest, FindGraphFunctionWithName) { + FunctionDefLibrary library; + EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1); + FunctionDef* new_function = library.add_function(); + SetUniqueGraphFunctionName("new_function", &library, new_function); + + EXPECT_NE( + FindGraphFunctionWithName(new_function->signature().name(), library), -1); +} + +TEST(GraphUtilsTest, FindNodeWithOp) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1); + + AddNode("A", "OpA", {}, {}, &graph); + EXPECT_NE(FindNodeWithOp("OpA", *graph.GetGraph()), -1); - NodeDef* node1; - TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node1)); - NodeDef* node2; - TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node2)); + graph.DeleteNodes({"A"}); + EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1); +} + +TEST(GraphUtilsTest, SetUniqueGraphNodeName) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); + + NodeDef* node1 = AddNode("", "A", {}, {}, &graph); + NodeDef* node2 = AddNode("", "A", {}, {}, &graph); EXPECT_NE(node1->name(), node2->name()); - TF_EXPECT_OK(DeleteNodes({node1->name()}, &graph)); - NodeDef* node3; - TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node3)); + graph.DeleteNodes({node1->name()}); + NodeDef* node3 = AddNode("", "A", {}, {}, &graph); EXPECT_NE(node2->name(), node3->name()); } +TEST(GraphUtilsTest, SetUniqueFunctionNodeName) { + FunctionDef function = test::function::XTimesTwo(); + NodeDef node; + SetUniqueFunctionNodeName("abc", &function, &node); + for (const NodeDef& function_node : function.node_def()) { + EXPECT_NE(node.name(), function_node.name()); + } + auto* new_node = function.add_node_def(); + *new_node = node; + + NodeDef other; + SetUniqueFunctionNodeName("abc", &function, &other); + EXPECT_NE(other.name(), new_node->name()); +} + +TEST(GraphUtilsTest, SetUniqueGraphFunctionName) { + FunctionDefLibrary library; + FunctionDef* new_function = library.add_function(); + SetUniqueGraphFunctionName("new_function", &library, new_function); + + FunctionDef* other_function = library.add_function(); + SetUniqueGraphFunctionName("new_function", &library, other_function); + EXPECT_NE(new_function->signature().name(), + other_function->signature().name()); +} + } // namespace } // namespace graph_utils } // namespace grappler |