aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-16 08:53:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-16 08:56:40 -0700
commit5764747347c5a7b3e868ecc8943a397e304a0a92 (patch)
tree2f3be95ed6f81127c91545b40c478775318900ed /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent1c697bc9094365cf5dab1ec1550eba019dffa3b8 (diff)
Optimize max/min reductions over monotonic functions
PiperOrigin-RevId: 200843761
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc54
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9d500f8f54..d518685216 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2600,6 +2600,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
};
+// Performs conversions like:
+// Max(Sqrt(x)) => Sqrt(Max(x))
+// Checks for a max/min reduction over element-wise monotonic functions, such
+// as Sqrt, Sigmoid, Tanh, etc.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+ explicit OptimizeMaxOrMinOfMonotonicStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+ ctx_ext) {}
+ ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMax(*node) || IsMin(*node);
+ }
+
+ Status TrySimplify(NodeDef* reduction_node,
+ string* simplified_node_name) override {
+ NodeDef* inner_function;
+ TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+ // Optimize only if:
+ // 1. inner_function's Op is element-wise monotonic
+ // 2. inner_function's output is not being consumed elsewhere.
+ if (IsElementWiseMonotonic(*inner_function) &&
+ (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ // Swap the first inputs of the inner function Op & the reduction Op.
+ NodeDef* inner_input;
+ TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
+ reduction_node->set_input(0, inner_input->name());
+ UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ }
+ return Status::OK();
+ }
+
+ void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name && consumer->name() != new_input) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2878,6 +2930,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
if (options_.convert_log1p)
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");