diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-16 08:53:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-16 08:56:40 -0700 |
commit | 5764747347c5a7b3e868ecc8943a397e304a0a92 (patch) | |
tree | 2f3be95ed6f81127c91545b40c478775318900ed /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 1c697bc9094365cf5dab1ec1550eba019dffa3b8 (diff) |
Optimize max/min reductions over monotonic functions
PiperOrigin-RevId: 200843761
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9d500f8f54..d518685216 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2600,6 +2600,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } }; +// Performs conversions like: +// Max(Sqrt(x)) => Sqrt(Max(x)) +// Checks for a max/min reduction over element-wise monotonic functions, such +// as Sqrt, Sigmoid, Tanh, etc. +class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { + public: + explicit OptimizeMaxOrMinOfMonotonicStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx, + ctx_ext) {} + ~OptimizeMaxOrMinOfMonotonicStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMax(*node) || IsMin(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + NodeDef* inner_function; + TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function)); + // Optimize only if: + // 1. inner_function's Op is element-wise monotonic + // 2. inner_function's output is not being consumed elsewhere. + if (IsElementWiseMonotonic(*inner_function) && + (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) { + // Swap the first inputs of the inner function Op & the reduction Op. + NodeDef* inner_input; + TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input)); + inner_function->set_input(0, reduction_node->name()); + UpdateConsumersAvoidingLoop(inner_function, reduction_node->name()); + reduction_node->set_input(0, inner_input->name()); + UpdateConsumersAvoidingLoop(reduction_node, inner_function->name()); + } + return Status::OK(); + } + + void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) { + const string& node_name = node->name(); + const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name); + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + if (consumer->input(i) == node_name && consumer->name() != new_input) { + consumer->set_input(i, new_input); + ctx().node_map->UpdateInput(consumer->name(), node_name, new_input); + } + } + AddToOptimizationQueue(consumer); + } + } +}; + } // namespace class UniqueNodes { @@ -2878,6 +2930,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); if (options_.convert_log1p) pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); + if (options_.optimize_max_or_min_of_monotonic) + pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); |