aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-08 14:42:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 14:46:40 -0800
commit04d33df3058a9e172659cb6ba9e5bc8f1412ec42 (patch)
tree00b6e468d8769ce74f442533b86853a17bf09f85 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
parent2dd2f9d04037b7c9b137e5ce3638506e1f013e13 (diff)
Add/AddN optimizer/rewriter
Collapse a sub-graph of Add/AddN operations of fully specified and identical shapes to a single AddN operation. PiperOrigin-RevId: 188392302
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.h')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h29
1 files changed, 25 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index afd538db40..9cff8ca9d0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -32,9 +32,14 @@ constexpr char kArithmeticOptimizer[] = "ArithmeticOptimizer";
// run a model.
class ArithmeticOptimizer : public GraphOptimizer {
public:
- ArithmeticOptimizer() : opt_level_(RewriterConfig::ON) {}
+ ArithmeticOptimizer()
+ : opt_level_(RewriterConfig::ON),
+ options_(ArithmeticOptimizerOptions::Default(RewriterConfig::ON)) {}
+
explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level)
- : opt_level_(opt_level) {}
+ : opt_level_(opt_level),
+ options_(ArithmeticOptimizerOptions::Default(opt_level)) {}
+
~ArithmeticOptimizer() override {}
string name() const override { return "arithmetic_optimizer"; };
@@ -46,6 +51,21 @@ class ArithmeticOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
+ friend class ArithmeticOptimizerTest;
+
+ // Granular control for arithmetic optimizer stages
+ struct ArithmeticOptimizerOptions {
+ // rewrite a tree of Add/AddN ops with a single AddN
+ bool enable_add_to_addn_combining;
+
+ // Choose which arithmetic optimizer stages will be enabled for a given
+ // optimization level by default.
+ static ArithmeticOptimizerOptions Default(
+ RewriterConfig::Toggle opt_level) {
+ return {/*enable_add_to_addn_combining*/ true};
+ }
+ };
+
// Returns true is a node with given name and the optimizer prefix already
// exists.
string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
@@ -97,13 +117,14 @@ class ArithmeticOptimizer : public GraphOptimizer {
SetVector<NodeDef*>* nodes_to_simplify);
RewriterConfig::Toggle opt_level_;
+ ArithmeticOptimizerOptions options_;
- bool fetch_nodes_known_;
+ bool fetch_nodes_known_ = false;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
FrameMap frame_map_;
std::unique_ptr<GraphProperties> graph_properties_;
- GraphDef* optimized_graph_; // Not owned.
+ GraphDef* optimized_graph_ = nullptr; // Not owned.
};
} // end namespace grappler