diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-08 14:42:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-08 14:46:40 -0800 |
commit | 04d33df3058a9e172659cb6ba9e5bc8f1412ec42 (patch) | |
tree | 00b6e468d8769ce74f442533b86853a17bf09f85 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | |
parent | 2dd2f9d04037b7c9b137e5ce3638506e1f013e13 (diff) |
Add/AddN optimizer/rewriter
Collapse a sub-graph of Add/AddN operations of fully specified
and identical shapes to a single AddN operation.
PiperOrigin-RevId: 188392302
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index afd538db40..9cff8ca9d0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -32,9 +32,14 @@ constexpr char kArithmeticOptimizer[] = "ArithmeticOptimizer"; // run a model. class ArithmeticOptimizer : public GraphOptimizer { public: - ArithmeticOptimizer() : opt_level_(RewriterConfig::ON) {} + ArithmeticOptimizer() + : opt_level_(RewriterConfig::ON), + options_(ArithmeticOptimizerOptions::Default(RewriterConfig::ON)) {} + explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level) - : opt_level_(opt_level) {} + : opt_level_(opt_level), + options_(ArithmeticOptimizerOptions::Default(opt_level)) {} + ~ArithmeticOptimizer() override {} string name() const override { return "arithmetic_optimizer"; }; @@ -46,6 +51,21 @@ class ArithmeticOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: + friend class ArithmeticOptimizerTest; + + // Granular control for arithmetic optimizer stages + struct ArithmeticOptimizerOptions { + // rewrite a tree of Add/AddN ops with a single AddN + bool enable_add_to_addn_combining; + + // Choose which arithmetic optimizer stages will be enabled for a given + // optimization level by default. + static ArithmeticOptimizerOptions Default( + RewriterConfig::Toggle opt_level) { + return {/*enable_add_to_addn_combining*/ true}; + } + }; + // Returns true is a node with given name and the optimizer prefix already // exists. string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; @@ -97,13 +117,14 @@ class ArithmeticOptimizer : public GraphOptimizer { SetVector<NodeDef*>* nodes_to_simplify); RewriterConfig::Toggle opt_level_; + ArithmeticOptimizerOptions options_; - bool fetch_nodes_known_; + bool fetch_nodes_known_ = false; std::unordered_set<string> nodes_to_preserve_; std::unique_ptr<NodeMap> node_map_; FrameMap frame_map_; std::unique_ptr<GraphProperties> graph_properties_; - GraphDef* optimized_graph_; // Not owned. + GraphDef* optimized_graph_ = nullptr; // Not owned. }; } // end namespace grappler |