aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_utils_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc229
1 files changed, 145 insertions, 84 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 00f66c9bc1..59ed79ab8f 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -23,134 +24,194 @@ namespace grappler {
namespace graph_utils {
namespace {
-class GraphUtilsTest : public ::testing::Test {};
-
-TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
- GraphDef graph;
- NodeDef* bool_node;
- TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node));
- EXPECT_TRUE(ContainsNodeWithName(bool_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeBool) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
+ EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph()));
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
- GraphDef graph;
- NodeDef* double_node;
- TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node));
- EXPECT_TRUE(ContainsNodeWithName(double_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
+ EXPECT_TRUE(
+ ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph()));
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
- GraphDef graph;
- NodeDef* float_node;
- TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node));
- EXPECT_TRUE(ContainsNodeWithName(float_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
+ EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph()));
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
- GraphDef graph;
- NodeDef* int_node;
- TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node));
- EXPECT_TRUE(ContainsNodeWithName(int_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeInt) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
+ EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph()));
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
- GraphDef graph;
- NodeDef* int64_node;
- TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node));
- EXPECT_TRUE(ContainsNodeWithName(int64_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
+ EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph()));
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
- GraphDef graph;
- NodeDef* string_node;
- TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node));
- EXPECT_TRUE(ContainsNodeWithName(string_node->name(), graph));
+TEST(GraphUtilsTest, AddScalarConstNodeString) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
+ EXPECT_TRUE(
+ ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph()));
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
}
-TEST_F(GraphUtilsTest, Compare) {
- GraphDef graphA;
- GraphDef graphB;
- EXPECT_TRUE(Compare(graphA, graphB));
+TEST(GraphUtilsTest, Compare) {
+ GraphDef graph_def_a;
+ MutableGraphView graph_a(&graph_def_a);
+ GraphDef graph_def_b;
+ MutableGraphView graph_b(&graph_def_b);
+
+ EXPECT_TRUE(Compare(graph_def_a, graph_def_b));
- NodeDef* nodeA;
- TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graphA, &nodeA));
- NodeDef* nodeB;
- TF_EXPECT_OK(AddNode("B", "OpB", {"A"}, {}, &graphA, &nodeB));
- EXPECT_FALSE(Compare(graphA, graphB));
+ AddNode("A", "OpA", {}, {}, &graph_a);
+ AddNode("B", "OpB", {"A"}, {}, &graph_a);
+ EXPECT_FALSE(Compare(graph_def_a, graph_def_b));
- graphB.mutable_node()->CopyFrom(graphA.node());
- EXPECT_TRUE(Compare(graphA, graphB));
+ graph_def_b.mutable_node()->CopyFrom(graph_def_a.node());
+ EXPECT_TRUE(Compare(graph_def_a, graph_def_b));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithName) {
- GraphDef graph;
- EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
- NodeDef* node;
- TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
- EXPECT_TRUE(ContainsNodeWithName("A", graph));
+ AddNode("A", "OpA", {}, {}, &graph);
+ EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph()));
- TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
- EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+ graph.DeleteNodes({"A"});
+ EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
- GraphDef graph;
- EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
+ FunctionDefLibrary library;
+ EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library));
+ FunctionDef* new_function = library.add_function();
+ SetUniqueGraphFunctionName("new_function", &library, new_function);
- NodeDef* node;
- TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
- EXPECT_TRUE(ContainsNodeWithOp("OpA", graph));
+ EXPECT_TRUE(
+ ContainsGraphFunctionWithName(new_function->signature().name(), library));
+}
- TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
- EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
}
-TEST_F(GraphUtilsTest, FindNodeWithName) {
- GraphDef graph;
- EXPECT_EQ(FindNodeWithName("A", graph), -1);
+TEST(GraphUtilsTest, ContainsNodeWithOp) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
- NodeDef* node;
- TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
- EXPECT_NE(FindNodeWithName("A", graph), -1);
+ AddNode("A", "OpA", {}, {}, &graph);
+ EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph()));
- TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
- EXPECT_EQ(FindNodeWithName("A", graph), -1);
+ graph.DeleteNodes({"A"});
+ EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
}
-TEST_F(GraphUtilsTest, FindNodeWithOp) {
- GraphDef graph;
- EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+TEST(GraphUtilsTest, FindGraphNodeWithName) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
- NodeDef* node;
- TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
- EXPECT_NE(FindNodeWithOp("OpA", graph), -1);
+ AddNode("A", "OpA", {}, {}, &graph);
+ EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
+
+ graph.DeleteNodes({"A"});
+ EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
+}
- TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
- EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+TEST(GraphUtilsTest, FindFunctionWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
}
-TEST_F(GraphUtilsTest, SetUniqueName) {
- GraphDef graph;
+TEST(GraphUtilsTest, FindGraphFunctionWithName) {
+ FunctionDefLibrary library;
+ EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
+ FunctionDef* new_function = library.add_function();
+ SetUniqueGraphFunctionName("new_function", &library, new_function);
+
+ EXPECT_NE(
+ FindGraphFunctionWithName(new_function->signature().name(), library), -1);
+}
+
+TEST(GraphUtilsTest, FindNodeWithOp) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+ EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+
+ AddNode("A", "OpA", {}, {}, &graph);
+ EXPECT_NE(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
- NodeDef* node1;
- TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node1));
- NodeDef* node2;
- TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node2));
+ graph.DeleteNodes({"A"});
+ EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+}
+
+TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+
+ NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
+ NodeDef* node2 = AddNode("", "A", {}, {}, &graph);
EXPECT_NE(node1->name(), node2->name());
- TF_EXPECT_OK(DeleteNodes({node1->name()}, &graph));
- NodeDef* node3;
- TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node3));
+ graph.DeleteNodes({node1->name()});
+ NodeDef* node3 = AddNode("", "A", {}, {}, &graph);
EXPECT_NE(node2->name(), node3->name());
}
+TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
+ FunctionDefLibrary library;
+ FunctionDef* new_function = library.add_function();
+ SetUniqueGraphFunctionName("new_function", &library, new_function);
+
+ FunctionDef* other_function = library.add_function();
+ SetUniqueGraphFunctionName("new_function", &library, other_function);
+ EXPECT_NE(new_function->signature().name(),
+ other_function->signature().name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler