aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/meta_optimizer.h
blob: 35d6a4559bb311738758086ad24752e82ffbcdcf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_

#include <memory>
#include <set>
#include <vector>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// Run the other grappler optimizers based on the specified rewriter config.
class MetaOptimizer : public GraphOptimizer {
 public:
  MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg);
  ~MetaOptimizer();

  string name() const override { return "meta_optimizer"; };

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override;

  void PrintResult();

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override;

 private:
  std::unique_ptr<GraphOptimizer> MakeNewOptimizer(
      const string& optimizer) const;

  // Initialize active optimizers from RewriterConfig toggles.
  Status InitializeOptimizers(
      std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
  // Initialize active optimizers from RewriterConfig optimizer names.
  Status InitializeOptimizersByName(
      std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
  // Initialize active optimizers from RewriterConfig.custom_optimizers.
  Status InitializeCustomGraphOptimizers(
      const std::set<string>& pre_initialized_optimizers,
      std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
  // Returns the config for a custom graph optimizer. Null if none was found.
  const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
      const string& name) const;

  // Run optimization pass over a single GrapplerItem. Meta optimizer might run
  // multiple such passes: 1) for the main graph 2) for the function library
  Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
                       GraphDef* optimized_graph);

  // Run optimization passes over the main graph and for functions in the
  // function library.
  Status OptimizeMainGraphAndFunctionLibrary(Cluster* cluster,
                                             const GrapplerItem& item,
                                             GraphDef* optimized_graph);

  DeviceBase* const cpu_device_;  // may be NULL
  RewriterConfig cfg_;

  // Thread pool used for launching optimizers asynchronously.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  struct OptimizerResult {
    string optimizer_name;
    string result;
  };

  struct GraphOptimizationResult {
    explicit GraphOptimizationResult(const string& id) : id(id) {}
    string id;
    std::vector<OptimizerResult> results;
  };

  Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster,
                      GrapplerItem* optimized_item, GraphDef* optimized_graph,
                      GraphOptimizationResult* optimization_result);

  std::vector<GraphOptimizationResult> optimization_results_;
};

// Presumably returns true when `cfg` enables the meta optimizer.
// NOTE(review): this is inferred from the name only, because the definition is
// not in this header. Confirm which RewriterConfig fields it inspects.
bool MetaOptimizerEnabled(const RewriterConfig& cfg);

// Runs the meta optimizer over `item`, using the optimizers selected by `cfg`.
//
// If `cpu_device` is non-null, it is the device used for executing ops during
// constant folding. If it is null, a new device is created for constant
// folding. For performance, pass in an existing `cpu_device` when possible.
//
// `optimized_graph` is the output parameter that receives the optimized graph.
// NOTE(review): whether `cluster` may be null is not visible from this header.
// Confirm against the definition.
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph);

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_