about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-16 08:53:17 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-16 08:56:40 -0700
commit: 5764747347c5a7b3e868ecc8943a397e304a0a92 (patch)
tree: 2f3be95ed6f81127c91545b40c478775318900ed /tensorflow/core/grappler
parent: 1c697bc9094365cf5dab1ec1550eba019dffa3b8 (diff)
Optimize max/min reductions over monotonic functions
PiperOrigin-RevId: 200843761
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- tensorflow/core/grappler/op_types.cc | 12
-rw-r--r-- tensorflow/core/grappler/op_types.h | 1
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 54
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 1
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 46
5 files changed, 114 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2227904dbf..b4ddd61c29 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
+// Returns true when `node`'s op is one of the unary ops this file classifies
+// as element-wise monotonic: Relu, Relu6, Sigmoid, Sqrt or Tanh.
+bool IsElementWiseMonotonic(const NodeDef& node) {
+  // Built once on first call; the set is allocated with `new` and is not
+  // deleted anywhere in this function.
+  static const std::unordered_set<string>* const kMonotonicOps =
+      CHECK_NOTNULL((new std::unordered_set<string>{
+          "Relu", "Relu6", "Sigmoid", "Sqrt", "Tanh"}));
+  const bool is_listed = kMonotonicOps->find(node.op()) != kMonotonicOps->end();
+  return is_listed;
+}
+
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
bool IsEnter(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7110a9c63d..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9d500f8f54..d518685216 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2600,6 +2600,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
};
+// Rewrites a max/min reduction over an element-wise monotonic function so
+// that the reduction runs first and the function is applied to its result:
+//   Max(Sqrt(x)) => Sqrt(Max(x))
+// Candidate inner functions are the ops accepted by IsElementWiseMonotonic
+// (Sqrt, Sigmoid, Tanh, etc.). Both nodes keep their names; only their first
+// inputs and the inputs of the reduction's consumers are rewired.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+  explicit OptimizeMaxOrMinOfMonotonicStage(
+      const GraphOptimizerContext& ctx,
+      const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+                                 ctx_ext) {}
+  ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+  // Only Max and Min reduction nodes are candidates for this stage.
+  bool IsSupported(const NodeDef* node) const override {
+    return IsMax(*node) || IsMin(*node);
+  }
+
+  // Swaps `reduction_node` with the monotonic function feeding its first
+  // input, when that is safe. Returns a non-OK status only if an input node
+  // cannot be resolved. `simplified_node_name` is left untouched: the rewrite
+  // is done in place and the affected nodes are re-queued explicitly.
+  Status TrySimplify(NodeDef* reduction_node,
+                     string* simplified_node_name) override {
+    // After the swap the name `reduction_node` yields Max(x) instead of
+    // Max(f(x)), so a node whose value must be preserved cannot be rewritten.
+    // NOTE(review): IsInPreserveSet is assumed to be provided by the stage
+    // base class -- confirm.
+    if (IsInPreserveSet(*reduction_node)) return Status::OK();
+
+    NodeDef* inner_function;
+    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+    // Optimize only if:
+    // 0. inner_function is not in the preserve set (its value changes too),
+    // 1. inner_function's Op is element-wise monotonic,
+    // 2. inner_function's output is not being consumed elsewhere.
+    if (IsInPreserveSet(*inner_function) ||
+        !IsElementWiseMonotonic(*inner_function) ||
+        NumNonControlOutputs(*inner_function, *ctx().node_map) != 1) {
+      return Status::OK();
+    }
+
+    // Fail early if the inner function's input cannot be resolved.
+    NodeDef* inner_input;
+    TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+
+    // Copy the input strings before overwriting them. Using the original
+    // input string (rather than inner_input->name()) keeps any output-port
+    // suffix such as "x:1".
+    const string inner_input_name = inner_function->input(0);
+    const string old_reduction_input = reduction_node->input(0);
+
+    // The reduction now reads what the inner function used to read. Keep the
+    // node map in sync with every input change so later lookups are correct.
+    reduction_node->set_input(0, inner_input_name);
+    ctx().node_map->UpdateInput(reduction_node->name(), old_reduction_input,
+                                inner_input_name);
+
+    // The inner function now reads the reduction's output.
+    inner_function->set_input(0, reduction_node->name());
+    ctx().node_map->UpdateInput(inner_function->name(), inner_input_name,
+                                reduction_node->name());
+
+    // Former consumers of the reduction must read the inner function instead.
+    // The inner function itself is skipped to avoid a self-loop.
+    UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+
+    AddToOptimizationQueue(reduction_node);
+    AddToOptimizationQueue(inner_function);
+    return Status::OK();
+  }
+
+  // Redirects every consumer input that is exactly `node`'s name to
+  // `new_input`, except on the consumer named `new_input` (which would create
+  // a self-loop). The node map is updated for each redirected input and every
+  // consumer is re-queued for optimization.
+  void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+    const string& node_name = node->name();
+    // Work on a copy: UpdateInput below modifies the node map's output sets.
+    const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+    for (NodeDef* consumer : consumers) {
+      for (int i = 0; i < consumer->input_size(); ++i) {
+        if (consumer->input(i) == node_name && consumer->name() != new_input) {
+          consumer->set_input(i, new_input);
+          ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+        }
+      }
+      AddToOptimizationQueue(consumer);
+    }
+  }
+};
+
} // namespace
class UniqueNodes {
@@ -2878,6 +2930,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
if (options_.convert_log1p)
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 9a6081dcd8..824ef35ef6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -63,6 +63,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool hoist_common_factor_out_of_aggregation = true;
bool hoist_cwise_unary_chains = false;
bool minimize_broadcasts = true;
+ bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
bool remove_identity_transpose = true;
bool remove_involution = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 177c237fe7..e1d55cdf5f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -269,6 +269,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.convert_log1p = true;
}
+
+ // Test helper: calls DisableAllStages on `optimizer`, then sets the
+ // optimize_max_or_min_of_monotonic option back to true so that this option
+ // is the one explicitly enabled for the test.
+ void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.optimize_max_or_min_of_monotonic = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3125,5 +3130,46 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
}
}
+// Builds final_out = Identity(Max(Sqrt(x))) and checks that the optimizer
+// rewrites it to Identity(Sqrt(Max(x))) without changing the fetched value or
+// the number of nodes.
+TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+  GrapplerItem item;
+  item.fetch = {"final_out"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  // Fatal assertion: element 0 is indexed below.
+  ASSERT_EQ(1, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+  auto tensors = EvaluateNodes(output, item.fetch);
+  // Fatal assertion: element 0 is indexed below.
+  ASSERT_EQ(1, tensors.size());
+
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  EXPECT_EQ(item.graph.node_size(), output.node_size());
+  // Check that the inputs are switched and that the consumer of the original
+  // reduction now reads from the inner function.
+  int required_node_count = 0;
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "sqrt") {
+      EXPECT_EQ("Sqrt", node.op());
+      ASSERT_EQ(1, node.input_size());
+      EXPECT_EQ("reduce_max", node.input(0));
+      ++required_node_count;
+    } else if (node.name() == "reduce_max") {
+      EXPECT_EQ("Max", node.op());
+      ASSERT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      ++required_node_count;
+    } else if (node.name() == "final_out") {
+      EXPECT_EQ("Identity", node.op());
+      ASSERT_EQ(1, node.input_size());
+      EXPECT_EQ("sqrt", node.input(0));
+      ++required_node_count;
+    }
+  }
+  EXPECT_EQ(3, required_node_count);
+}
+
} // namespace grappler
} // namespace tensorflow