aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc62
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc5
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.h3
3 files changed, 44 insertions, 26 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 85f877883c..e0ff9b17b1 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -107,8 +107,8 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
EXPECT_EQ("Const", node_d.op());
std::vector<string> fetch = {"d"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@@ -193,10 +193,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
// Check that the result nodes have the expected value.
std::vector<string> fetch = {"c3", "c20"};
- auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensor_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"add_child", "mul_child"};
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
@@ -436,10 +436,10 @@ TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
// Check that the reciprocals have the expected value.
std::vector<string> fetch = {"cf_half"};
- auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensor_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
@@ -647,8 +647,8 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
EXPECT_EQ("Const", new_d.op());
std::vector<string> fetch = {"e", "f"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(fetch.size(), tensors_expected.size());
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
@@ -671,7 +671,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
GrapplerItem item;
item.fetch.push_back("e");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
@@ -688,8 +688,8 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
if (node.name() == "e") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"e"}, {});
- auto expected = EvaluateNodes(item.graph, {"e"}, {});
+ auto folded = EvaluateNodes(output, {"e"});
+ auto expected = EvaluateNodes(item.graph, {"e"});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -699,7 +699,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
}
}
EXPECT_EQ(1, found);
- auto tensors = EvaluateNodes(output, item.fetch, {});
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
@@ -735,8 +735,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
if (node.name() == "i1") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"i1"}, {});
- auto expected = EvaluateNodes(item.graph, {"i1"}, {});
+ auto folded = EvaluateNodes(output, {"i1"});
+ auto expected = EvaluateNodes(item.graph, {"i1"});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -746,8 +746,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
if (node.name() == "i2") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"i2"}, {});
- auto expected = EvaluateNodes(item.graph, {"i2"}, {});
+ auto folded = EvaluateNodes(output, {"i2"});
+ auto expected = EvaluateNodes(item.graph, {"i2"});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -775,7 +775,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
GrapplerItem item;
item.fetch.push_back("i2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -794,6 +795,9 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
EXPECT_EQ("^p2", node.input(1));
}
}
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
@@ -865,8 +869,8 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
}
EXPECT_EQ(8, constant_folded);
- auto expected = EvaluateNodes(item.graph, outputs, {});
- auto optimized = EvaluateNodes(output, outputs, {});
+ auto expected = EvaluateNodes(item.graph, outputs);
+ auto optimized = EvaluateNodes(output, outputs);
ASSERT_EQ(expected.size(), optimized.size());
for (int i = 0; i < expected.size(); ++i) {
test::ExpectTensorEqual<int>(expected[i], optimized[i]);
@@ -1293,7 +1297,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
EXPECT_EQ(6, found_nodes);
std::vector<string> fetch = {"out1", "idx1"};
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(2, tensors.size());
const Tensor& out_value = tensors[0];
EXPECT_EQ(3 * 5, out_value.NumElements());
@@ -1803,6 +1807,12 @@ TEST_F(ConstantFoldingTest, LargeConstant) {
EXPECT_EQ(2, found);
EXPECT_GT(1024 * 1024, output.ByteSizeLong());
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
@@ -1948,8 +1958,8 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
}
std::vector<string> fetch = {"acc0"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
@@ -1983,7 +1993,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
"concat5", "concat6", "concat7", "concat8", "concat9"};
- auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}, {});
+ auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
@@ -2034,7 +2044,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
}
}
- auto tensors = EvaluateNodes(output, {"concat0"}, {});
+ auto tensors = EvaluateNodes(output, {"concat0"});
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -2132,8 +2142,8 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
}
std::vector<string> fetch = {"stack"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
- auto tensors = EvaluateNodes(output, fetch, {});
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 5c96359867..910b0acaef 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -40,6 +40,11 @@ GrapplerTest::GrapplerTest() {
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
+ const GraphDef& graph, const std::vector<string>& node_names) const {
+ return EvaluateNodes(graph, node_names, {});
+}
+
+std::vector<Tensor> GrapplerTest::EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names,
const std::vector<std::pair<string, Tensor>>& inputs) const {
std::unique_ptr<tensorflow::Session> session(NewSession(options_));
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index 4b160e7f16..3bc7bea454 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -35,6 +35,9 @@ class GrapplerTest : public ::testing::Test {
protected:
std::vector<Tensor> EvaluateNodes(
+ const GraphDef& graph, const std::vector<string>& node_names) const;
+
+ std::vector<Tensor> EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names,
const std::vector<std::pair<string, Tensor>>& inputs) const;