aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 15:47:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 17:44:56 -0700
commit4a9beef315c3e456e7f087b5b3205df99f4a0876 (patch)
treee0e298dab8364f6a2c8f9af4ef0174bf2088472a /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent37b8860e302d73845e74e1bfb6c3cb59207f2d77 (diff)
Add EvaluateNodes to tests: RemoveIdentityTransposesMultipleOutputs, RemoveTransposesWithControlDependency, CombineBitcasts, CombineAndRemoveBitcasts, RemoveRedundantCast
PiperOrigin-RevId: 195735234
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 741cc135a1..067adb359c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -1166,6 +1166,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
@@ -1178,6 +1183,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
EXPECT_EQ(node.input(2), "Split:2");
}
}
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
@@ -1194,6 +1203,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+ item.feed = {{"Placeholder", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveIdentityTranspose(&optimizer);
@@ -1204,6 +1218,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
EXPECT_EQ(2, outputs_node->input_size());
EXPECT_EQ(outputs_node->input(0), "outputs_const");
EXPECT_EQ(outputs_node->input(1), "^Placeholder");
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
@@ -1450,6 +1468,11 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
@@ -1461,6 +1484,10 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
EXPECT_EQ(3, output.node_size());
EXPECT_EQ(1, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
@@ -1475,6 +1502,11 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantBitcast(&optimizer);
@@ -1486,6 +1518,10 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Bitcast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
@@ -1499,6 +1535,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
+ item.feed = {{"inputs", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveRedundantCast(&optimizer);
@@ -1510,6 +1551,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
EXPECT_EQ(2, output.node_size());
EXPECT_EQ(0, CountOpNodes(output, "Cast"));
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {