diff options
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 62 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.h | 3 |
3 files changed, 44 insertions, 26 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 85f877883c..e0ff9b17b1 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -107,8 +107,8 @@ TEST_F(ConstantFoldingTest, SimpleFolding) { EXPECT_EQ("Const", node_d.op()); std::vector<string> fetch = {"d"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch, {}); - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); @@ -193,10 +193,10 @@ TEST_F(ConstantFoldingTest, AddTree) { // Check that the result nodes have the expected value. std::vector<string> fetch = {"c3", "c20"}; - auto tensor_expected = EvaluateNodes(item.graph, fetch, {}); + auto tensor_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(fetch.size(), tensor_expected.size()); fetch = {"add_child", "mul_child"}; - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(fetch.size(), tensors.size()); for (int i = 0; i < fetch.size(); i++) { test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]); @@ -436,10 +436,10 @@ TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) { // Check that the reciprocals have the expected value. std::vector<string> fetch = {"cf_half"}; - auto tensor_expected = EvaluateNodes(item.graph, fetch, {}); + auto tensor_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(fetch.size(), tensor_expected.size()); fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"}; - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(fetch.size(), tensors.size()); for (int i = 0; i < fetch.size(); i++) { test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]); @@ -647,8 +647,8 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { EXPECT_EQ("Const", new_d.op()); std::vector<string> fetch = {"e", "f"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch, {}); - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(fetch.size(), tensors_expected.size()); EXPECT_EQ(fetch.size(), tensors.size()); for (int i = 0; i < fetch.size(); i++) { @@ -671,7 +671,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { GrapplerItem item; item.fetch.push_back("e"); TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; @@ -688,8 +688,8 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { if (node.name() == "e") { EXPECT_EQ("Const", node.op()); ++found; - auto folded = EvaluateNodes(output, {"e"}, {}); - auto expected = EvaluateNodes(item.graph, {"e"}, {}); + auto folded = EvaluateNodes(output, {"e"}); + auto expected = EvaluateNodes(item.graph, {"e"}); EXPECT_EQ(1, expected.size()); EXPECT_EQ(1, folded.size()); test::ExpectTensorEqual<int>(folded[0], expected[0]); @@ -699,7 +699,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { } } EXPECT_EQ(1, found); - auto tensors = EvaluateNodes(output, item.fetch, {}); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]); } @@ -735,8 +735,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) { if (node.name() == "i1") { EXPECT_EQ("Const", node.op()); ++found; - auto folded = EvaluateNodes(output, {"i1"}, {}); - auto expected = EvaluateNodes(item.graph, {"i1"}, {}); + auto folded = EvaluateNodes(output, {"i1"}); + auto expected = EvaluateNodes(item.graph, {"i1"}); EXPECT_EQ(1, expected.size()); EXPECT_EQ(1, folded.size()); test::ExpectTensorEqual<int>(folded[0], expected[0]); @@ -746,8 +746,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) { if (node.name() == "i2") { EXPECT_EQ("Const", node.op()); ++found; - auto folded = EvaluateNodes(output, {"i2"}, {}); - auto expected = EvaluateNodes(item.graph, {"i2"}, {}); + auto folded = EvaluateNodes(output, {"i2"}); + auto expected = EvaluateNodes(item.graph, {"i2"}); EXPECT_EQ(1, expected.size()); EXPECT_EQ(1, folded.size()); test::ExpectTensorEqual<int>(folded[0], expected[0]); @@ -775,7 +775,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) { GrapplerItem item; item.fetch.push_back("i2"); TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); @@ -794,6 +795,9 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) { EXPECT_EQ("^p2", node.input(1)); } } + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]); } TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) { @@ -865,8 +869,8 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) { } EXPECT_EQ(8, constant_folded); - auto expected = EvaluateNodes(item.graph, outputs, {}); - auto optimized = EvaluateNodes(output, outputs, {}); + auto expected = EvaluateNodes(item.graph, outputs); + auto optimized = EvaluateNodes(output, outputs); ASSERT_EQ(expected.size(), optimized.size()); for (int i = 0; i < expected.size(); ++i) { test::ExpectTensorEqual<int>(expected[i], optimized[i]); @@ -1293,7 +1297,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(6, found_nodes); std::vector<string> fetch = {"out1", "idx1"}; - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(2, tensors.size()); const Tensor& out_value = tensors[0]; EXPECT_EQ(3 * 5, out_value.NumElements()); @@ -1803,6 +1807,12 @@ TEST_F(ConstantFoldingTest, LargeConstant) { EXPECT_EQ(2, found); EXPECT_GT(1024 * 1024, output.ByteSizeLong()); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); } TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) { @@ -1948,8 +1958,8 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) { } std::vector<string> fetch = {"acc0"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch, {}); - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); @@ -1983,7 +1993,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) { item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4", "concat5", "concat6", "concat7", "concat8", "concat9"}; - auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}, {}); + auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}); EXPECT_EQ(1, tensors_expected.size()); ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; @@ -2034,7 +2044,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) { } } - auto tensors = EvaluateNodes(output, {"concat0"}, {}); + auto tensors = EvaluateNodes(output, {"concat0"}); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } @@ -2132,8 +2142,8 @@ TEST_F(ConstantFoldingTest, TrivialPack) { } std::vector<string> fetch = {"stack"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch, {}); - auto tensors = EvaluateNodes(output, fetch, {}); + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape()); diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 5c96359867..910b0acaef 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -40,6 +40,11 @@ GrapplerTest::GrapplerTest() { } std::vector<Tensor> GrapplerTest::EvaluateNodes( + const GraphDef& graph, const std::vector<string>& node_names) const { + return EvaluateNodes(graph, node_names, {}); +} + +std::vector<Tensor> GrapplerTest::EvaluateNodes( const GraphDef& graph, const std::vector<string>& node_names, const std::vector<std::pair<string, Tensor>>& inputs) const { std::unique_ptr<tensorflow::Session> session(NewSession(options_)); diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 4b160e7f16..3bc7bea454 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -35,6 +35,9 @@ class GrapplerTest : public ::testing::Test { protected: std::vector<Tensor> EvaluateNodes( + const GraphDef& graph, const std::vector<string>& node_names) const; + + std::vector<Tensor> EvaluateNodes( const GraphDef& graph, const std::vector<string>& node_names, const std::vector<std::pair<string, Tensor>>& inputs) const; |