diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-28 22:11:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-28 22:13:59 -0700 |
commit | 6874e1ef40c4189d96c105227f60b507953f95d3 (patch) | |
tree | ca8c6d0f04aec4c2a32c99582084116606f9adaa /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 50839154899377f89367d851f6d1e2c45fcd6431 (diff) |
Convert exp(x-1) into expm1(x).
PiperOrigin-RevId: 202598404
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d0e6b04679..3f6c04a5b5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -274,6 +274,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.optimize_max_or_min_of_monotonic = true; } + + void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_expm1 = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2533,6 +2538,43 @@ TEST_F(ArithmeticOptimizerTest, Log1p) { CompareGraphs(want, got); } +TEST_F(ArithmeticOptimizerTest, Expm1) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2}); + auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2}); + auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2}); + auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2); + auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3); + Output out1 = ops::Exp(s.WithOpName("out1"), s12); + Output out2 = ops::Exp(s.WithOpName("out2"), s23); + + GrapplerItem item; + item.fetch = {"out1", "out2"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(2, tensors_expected.size()); + + GraphDef got; + ArithmeticOptimizer optimizer; + EnableOnlyExpm1(&optimizer); + OptimizeAndPrune(&optimizer, &item, &got); + auto tensors = EvaluateNodes(got, item.fetch); + EXPECT_EQ(2, tensors.size()); + + GraphDef want; + AddNode("x1", "Const", {}, {}, &want); + AddNode("x2", "Const", {}, {}, &want); + AddNode("x3", "Const", {}, {}, &want); + AddNode("s23", "Sub", {"x2", "x3"}, {}, &want); + AddNode("out1", "Expm1", + {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {}, + &want); + AddNode("out2", "Exp", {"s23"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); |