diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/loop_optimizer.h')
-rw-r--r-- | tensorflow/core/grappler/optimizers/loop_optimizer.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index 106d4628ae..c1b0321e4e 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -17,13 +17,17 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ #include <unordered_set> +#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { +constexpr char kLoopOptimizer[] = "LoopOptimizer"; + class LoopOptimizer : public GraphOptimizer { public: LoopOptimizer() : opt_level_(RewriterConfig::ON) {} @@ -40,7 +44,29 @@ class LoopOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: + Status LoopInvariantNodeMotion(); + Status FindInvariantNodes(NodeDef* node); + Status RevertInvariantNodes(); + Status MoveInvariantNodes(const int frame_id); + Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs, + const int frame_id); + Status LINMHandleConst(NodeDef* node, const int num_outputs, + const int frame_id); + Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs); + + std::map<NodeDef*, int> invariant_nodes_; + std::set<int> empty_set_; + std::map<int, std::set<int>> frame_children_; + std::map<int, int> frame_parent_; + std::map<int, const NodeDef*> loop_cond_; + std::map<int, std::vector<NodeDef*>> invariant_enters_; + int new_enter_id_; RewriterConfig::Toggle opt_level_; + + std::unique_ptr<NodeMap> node_map_; + FrameMap frame_map_; + std::unique_ptr<GraphProperties> graph_properties_; + GraphDef* optimized_graph_; // Not owned. }; } // end namespace grappler |