diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-20 13:56:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:01:17 -0700 |
commit | 17dbe77f5ad47e8fd71924f12b3bc53c05afbacf (patch) | |
tree | 46142d37c97ca378139cb73785171903a74f3516 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | d388770922ad1afa95e55597a33836fe74035c75 (diff) |
Fix bug in Pow optimizer rule when broadcasting is involved.
Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it.
PiperOrigin-RevId: 213876779
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 88839d944c..77f3c64c65 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2}); auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2}); auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2}); + auto z = ops::Const(s.WithOpName("z"), {42.0f}, {}); + auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3}); + auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3}); Output out2 = ops::Pow(s.WithOpName("out2"), x, y2); Output out1 = ops::Pow(s.WithOpName("out1"), x, y1); Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5); @@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5); Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1); Output out = ops::Pow(s.WithOpName("out"), x, y); + Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones); + Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros); GrapplerItem item; - item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"}; + item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", + "out_1", "out", "out_bcast1", "out_bcast2"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - EXPECT_EQ(7, tensors_expected.size()); + EXPECT_EQ(9, tensors_expected.size()); GraphDef got; ArithmeticOptimizer optimizer; EnableOnlyConvertPow(&optimizer); OptimizeAndPrune(&optimizer, &item, &got); auto tensors = EvaluateNodes(got, item.fetch); - EXPECT_EQ(7, tensors.size()); + EXPECT_EQ(9, tensors.size()); - for (int i = 0; i < 7; ++i) { + for (int i = 0; i < tensors.size(); ++i) { EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements()); test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6); } @@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("y_.5", "Const", {}, {}, &want); AddNode("y_1", "Const", {}, {}, &want); AddNode("y", "Const", {}, {}, &want); + AddNode("z", "Const", {}, {}, &want); + AddNode("ones", "Const", {}, {}, &want); + AddNode("zeros", "Const", {}, {}, &want); AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want); AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want); AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want); @@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want); AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want); AddNode("out", "Pow", {"x", "y"}, {}, &want); + AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want); + AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want); CompareGraphs(want, got); } |