aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-22 08:02:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-22 08:05:05 -0700
commit96f4fefefab793f9673252b1937f23e5e3a9801a (patch)
treecbc6a37a53b7cdc0413bb40a987b2d5d7d496a45 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent2e4a4b4a3b994abb118c247ed9ecc7cc79a26950 (diff)
Automated g4 rollback of changelist 197527651
PiperOrigin-RevId: 197562826
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc57
1 files changed, 0 insertions, 57 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 8b8eedfbb3..64fdc8a83b 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -173,11 +173,6 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
}
- void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
- DisableAllStages(optimizer);
- optimizer->options_.convert_pow = true;
- }
-
void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_idempotent = true;
@@ -2248,58 +2243,6 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
}
}
-TEST_F(ArithmeticOptimizerTest, ConvertPow) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
- auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
- auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
- auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
- auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
- auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
- auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
- auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
- Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
- Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
- Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
- Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
- Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
- Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
- Output out = ops::Pow(s.WithOpName("out"), x, y);
-
- GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
-
- GraphDef got;
- ArithmeticOptimizer optimizer;
- EnableOnlyConvertPow(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &got);
- auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
-
- GraphDef want;
- AddNode("x", "Const", {}, {}, &want);
- AddNode("y2", "Const", {}, {}, &want);
- AddNode("y1", "Const", {}, {}, &want);
- AddNode("y.5", "Const", {}, {}, &want);
- AddNode("y0", "Const", {}, {}, &want);
- AddNode("y_.5", "Const", {}, {}, &want);
- AddNode("y_1", "Const", {}, {}, &want);
- AddNode("y", "Const", {}, {}, &want);
- AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
- AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
- AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
- AddNode("out0", "Const",
- {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
- AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
- AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
- AddNode("out", "Pow", {"x", "y"}, {}, &want);
-
- CompareGraphs(want, got);
-}
-
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();