aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/loop_optimizer.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/loop_optimizer.h')
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.h26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h
index 106d4628ae..c1b0321e4e 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h
@@ -17,13 +17,17 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
+constexpr char kLoopOptimizer[] = "LoopOptimizer";
+
class LoopOptimizer : public GraphOptimizer {
public:
LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
@@ -40,7 +44,29 @@ class LoopOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
+ Status LoopInvariantNodeMotion();
+ Status FindInvariantNodes(NodeDef* node);
+ Status RevertInvariantNodes();
+ Status MoveInvariantNodes(const int frame_id);
+ Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs,
+ const int frame_id);
+ Status LINMHandleConst(NodeDef* node, const int num_outputs,
+ const int frame_id);
+ Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs);
+
+ std::map<NodeDef*, int> invariant_nodes_;
+ std::set<int> empty_set_;
+ std::map<int, std::set<int>> frame_children_;
+ std::map<int, int> frame_parent_;
+ std::map<int, const NodeDef*> loop_cond_;
+ std::map<int, std::vector<NodeDef*>> invariant_enters_;
+ int new_enter_id_;
RewriterConfig::Toggle opt_level_;
+
+ std::unique_ptr<NodeMap> node_map_;
+ FrameMap frame_map_;
+ std::unique_ptr<GraphProperties> graph_properties_;
+ GraphDef* optimized_graph_; // Not owned.
};
} // end namespace grappler