aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc66
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d0e6b04679..c387b00303 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -141,6 +141,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.dedup_computations = false;
options.combine_add_to_addn = false;
options.convert_sqrt_div_to_rsqrt_mul = false;
+ options.convert_pow = false;
+ options.convert_log1p = false;
+ options.optimize_max_or_min_of_monotonic = false;
options.fold_conjugate_into_transpose = false;
options.fold_multiply_into_conv = false;
options.fold_transpose_into_matmul = false;
@@ -158,6 +161,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.reorder_cast_and_transpose = false;
options.replace_mul_with_square = false;
options.simplify_aggregation = false;
+ options.unary_ops_composition = false;
optimizer->options_ = options;
}
@@ -274,6 +278,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
+
+ void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.unary_ops_composition = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3159,5 +3168,62 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output log = ops::Log(s.WithOpName("log"), sqrt);
+ Output relu = ops::Relu(s.WithOpName("relu"), log);
+ Output final_out = ops::Identity(s.WithOpName("final_out"), relu);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // Place all nodes on CPU.
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ item.graph.mutable_node(i)->set_device("/device:CPU:0");
+ }
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyUnaryOpsComposition(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ EXPECT_EQ(3, output.node_size());
+
+ // Check that Sqrt/Log/Relu were replaced with a single op.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "final_out") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("relu/unary_ops_composition", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "relu/unary_ops_composition") {
+ EXPECT_EQ("_UnaryOpsComposition", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+
+ auto op_names = node.attr().at("op_names").list().s();
+ EXPECT_EQ(3, op_names.size());
+ EXPECT_EQ("Sqrt", op_names[0]);
+ EXPECT_EQ("Log", op_names[1]);
+ EXPECT_EQ("Relu", op_names[2]);
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
} // namespace grappler
} // namespace tensorflow