aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-20 13:56:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:01:17 -0700
commit17dbe77f5ad47e8fd71924f12b3bc53c05afbacf (patch)
tree46142d37c97ca378139cb73785171903a74f3516 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parentd388770922ad1afa95e55597a33836fe74035c75 (diff)
Fix bug in Pow optimizer rule when broadcasting is involved.
Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it. PiperOrigin-RevId: 213876779
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc19
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 88839d944c..77f3c64c65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}