aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 22:11:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 22:13:59 -0700
commit6874e1ef40c4189d96c105227f60b507953f95d3 (patch)
treeca8c6d0f04aec4c2a32c99582084116606f9adaa /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent50839154899377f89367d851f6d1e2c45fcd6431 (diff)
Convert exp(x-1) into expm1(x).
PiperOrigin-RevId: 202598404
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d0e6b04679..3f6c04a5b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -274,6 +274,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
+
+ void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_expm1 = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2533,6 +2538,43 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
CompareGraphs(want, got);
}
+TEST_F(ArithmeticOptimizerTest, Expm1) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
+ auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
+ auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+ auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2);
+ auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3);
+ Output out1 = ops::Exp(s.WithOpName("out1"), s12);
+ Output out2 = ops::Exp(s.WithOpName("out2"), s23);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyExpm1(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(2, tensors.size());
+
+ GraphDef want;
+ AddNode("x1", "Const", {}, {}, &want);
+ AddNode("x2", "Const", {}, {}, &want);
+ AddNode("x3", "Const", {}, {}, &want);
+ AddNode("s23", "Sub", {"x2", "x3"}, {}, &want);
+ AddNode("out1", "Expm1",
+ {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
+ &want);
+ AddNode("out2", "Exp", {"s23"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();