diff options
author | 2018-06-16 08:53:17 -0700 | |
---|---|---|
committer | 2018-06-16 08:56:40 -0700 | |
commit | 5764747347c5a7b3e868ecc8943a397e304a0a92 (patch) | |
tree | 2f3be95ed6f81127c91545b40c478775318900ed /tensorflow/core/grappler | |
parent | 1c697bc9094365cf5dab1ec1550eba019dffa3b8 (diff) |
Optimize max/min reductions over monotonic functions
PiperOrigin-RevId: 200843761
Diffstat (limited to 'tensorflow/core/grappler')
5 files changed, 114 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 2227904dbf..b4ddd61c29 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } +bool IsElementWiseMonotonic(const NodeDef& node) { + static const std::unordered_set<string>* element_wise_monotonic_ops = + CHECK_NOTNULL((new std::unordered_set<string>{ + "Relu", + "Relu6", + "Sigmoid", + "Sqrt", + "Tanh", + })); + return element_wise_monotonic_ops->count(node.op()) > 0; +} + bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } bool IsEnter(const NodeDef& node) { diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 7110a9c63d..2de7d8cc9a 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9d500f8f54..d518685216 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2600,6 +2600,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } }; +// Performs conversions like: +// Max(Sqrt(x)) => Sqrt(Max(x)) +// Checks for a max/min reduction over element-wise monotonic functions, such +// as Sqrt, Sigmoid, Tanh, etc. 
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { + public: + explicit OptimizeMaxOrMinOfMonotonicStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx, + ctx_ext) {} + ~OptimizeMaxOrMinOfMonotonicStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMax(*node) || IsMin(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + NodeDef* inner_function; + TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); + // Optimize only if: + // 1. inner_function's Op is element-wise monotonic + // 2. inner_function's output is not being consumed elsewhere. + if (IsElementWiseMonotonic(*inner_function) && + (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + // Swap the first inputs of the inner function Op & the reduction Op. + NodeDef* inner_input; + TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); + reduction_node->set_input(0, inner_input->name()); + UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + } + return Status::OK(); + } + + void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + const string& node_name = node->name(); + const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + if (consumer->input(i) == node_name && consumer->name() != new_input) { + consumer->set_input(i, new_input); + ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); + } + } + AddToOptimizationQueue(consumer); + } + } +}; + } // namespace class UniqueNodes { @@ -2878,6 +2930,8 @@ Status 
ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); if (options_.convert_log1p) pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); + if (options_.optimize_max_or_min_of_monotonic) + pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 9a6081dcd8..824ef35ef6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -63,6 +63,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool hoist_common_factor_out_of_aggregation = true; bool hoist_cwise_unary_chains = false; bool minimize_broadcasts = true; + bool optimize_max_or_min_of_monotonic = true; bool remove_idempotent = true; bool remove_identity_transpose = true; bool remove_involution = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 177c237fe7..e1d55cdf5f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -269,6 +269,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.convert_log1p = true; } + + void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.optimize_max_or_min_of_monotonic = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -3125,5 +3130,46 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { } } +TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) { + tensorflow::Scope s = 
tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "sqrt") { + EXPECT_EQ("Sqrt", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Max", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + } // namespace grappler } // namespace tensorflow