diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d0e6b04679..c387b00303 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -141,6 +141,9 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.dedup_computations = false; options.combine_add_to_addn = false; options.convert_sqrt_div_to_rsqrt_mul = false; + options.convert_pow = false; + options.convert_log1p = false; + options.optimize_max_or_min_of_monotonic = false; options.fold_conjugate_into_transpose = false; options.fold_multiply_into_conv = false; options.fold_transpose_into_matmul = false; @@ -158,6 +161,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.reorder_cast_and_transpose = false; options.replace_mul_with_square = false; options.simplify_aggregation = false; + options.unary_ops_composition = false; optimizer->options_ = options; } @@ -274,6 +278,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.optimize_max_or_min_of_monotonic = true; } + + void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.unary_ops_composition = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -3159,5 +3168,62 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) { EXPECT_EQ(2, required_node_count); } +TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x); + Output log = ops::Log(s.WithOpName("log"), sqrt); + Output relu = ops::Relu(s.WithOpName("relu"), log); + Output final_out = ops::Identity(s.WithOpName("final_out"), relu); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyUnaryOpsComposition(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + EXPECT_EQ(3, output.node_size()); + + // Check that Sqrt/Log/Relu were replaced with a single op. + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "final_out") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("relu/unary_ops_composition", node.input(0)); + ++required_node_count; + } else if (node.name() == "relu/unary_ops_composition") { + EXPECT_EQ("_UnaryOpsComposition", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + + auto op_names = node.attr().at("op_names").list().s(); + EXPECT_EQ(3, op_names.size()); + EXPECT_EQ("Sqrt", op_names[0]); + EXPECT_EQ("Log", op_names[1]); + EXPECT_EQ("Relu", op_names[2]); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); +} + } // namespace grappler } // namespace tensorflow |