diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-07 15:47:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 17:44:56 -0700 |
commit | 4a9beef315c3e456e7f087b5b3205df99f4a0876 (patch) | |
tree | e0e298dab8364f6a2c8f9af4ef0174bf2088472a /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 37b8860e302d73845e74e1bfb6c3cb59207f2d77 (diff) |
Add EvaluateNodes to tests: RemoveIdentityTransposesMultipleOutputs, RemoveTransposesWithControlDependency, CombineBitcasts, CombineAndRemoveBitcasts, RemoveRedundantCast
PiperOrigin-RevId: 195735234
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 741cc135a1..067adb359c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -1166,6 +1166,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveIdentityTranspose(&optimizer); @@ -1178,6 +1183,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) { EXPECT_EQ(node.input(2), "Split:2"); } } + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { @@ -1194,6 +1203,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3})); + item.feed = {{"Placeholder", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveIdentityTranspose(&optimizer); @@ -1204,6 +1218,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { EXPECT_EQ(2, outputs_node->input_size()); EXPECT_EQ(outputs_node->input(0), "outputs_const"); EXPECT_EQ(outputs_node->input(1), "^Placeholder"); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { @@ -1450,6 +1468,11 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantBitcast(&optimizer); @@ -1461,6 +1484,10 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { EXPECT_EQ(3, output.node_size()); EXPECT_EQ(1, CountOpNodes(output, "Bitcast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { @@ -1475,6 +1502,11 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantBitcast(&optimizer); @@ -1486,6 +1518,10 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { EXPECT_EQ(2, output.node_size()); EXPECT_EQ(0, CountOpNodes(output, "Bitcast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { @@ -1499,6 +1535,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantCast(&optimizer); @@ -1510,6 +1551,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { EXPECT_EQ(2, output.node_size()); EXPECT_EQ(0, CountOpNodes(output, "Cast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) { |