aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/constant_folding.h
blob: 8593b3e0b878e50c75bab8fa8b3e377aabf8d257 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// String tags associated with the constant folding optimizer.
// NOTE(review): presumably used when naming the constant nodes and control
// dependency nodes that the optimizer creates -- confirm against the
// implementation file before relying on that.
constexpr char kConstantFoldingConst[] = "ConstantFolding";
constexpr char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";

// Constant folding optimization for a graph.
//
// Implements the GraphOptimizer interface (name / Optimize / Feedback). The
// private section declares the individual rewrite steps and the state they
// share; all non-trivial member functions are defined outside this header.
class ConstantFolding : public GraphOptimizer {
 public:
  // NOTE(review): presumably fills `*node` with a constant node called `name`
  // that holds the value of `tensor` -- confirm against the definition.
  static Status CreateNodeDef(const string& name, const TensorValue& tensor,
                              NodeDef* node);
  // NOTE(review): presumably returns the input string to use for a control
  // dependency on `input_name`, possibly updating `graph` and `node_map` --
  // confirm against the definition.
  static string AddControlDependency(const string& input_name, GraphDef* graph,
                                     NodeMap* node_map);

  // Both constructors take the CPU device to use; the second one additionally
  // takes the optimization level. See the comment on cpu_device_ below for
  // the relationship between `cpu_device` and owned_device_.
  explicit ConstantFolding(DeviceBase* cpu_device);
  ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);

  ~ConstantFolding() override = default;

  // Returns the fixed optimizer name "constant folding".
  string name() const override { return "constant folding"; }

  // GraphOptimizer entry point. Returns a Status; `output` is the out
  // parameter for the resulting graph.
  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* output) override;

  // GraphOptimizer feedback hook.
  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;

 private:
  // Helpers for nodes produced by this optimizer, keyed by the original
  // `node` and a `suffix`.
  // NOTE(review): presumably the first builds the name and the second checks
  // whether a node with that name is already present -- confirm.
  string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
  bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;

  bool IsReallyConstant(const NodeDef& node) const;

  // Materialization steps driven by inferred graph properties.
  Status MaterializeShapes(const GraphProperties& properties);

  Status MaterializeBroadcastGradientArgs(const NodeDef& node,
                                          const GraphProperties& properties);
  Status MaterializeReductionIndices(NodeDef* node,
                                     const GraphProperties& properties);

  Status MaterializeConstants(const GraphProperties& properties);
  bool IsFoldable(const NodeDef& node) const;

  // Evaluates `node` on `inputs`, writing the results to `*output`.
  Status EvaluateNode(const NodeDef& node,
                      const gtl::InlinedVector<TensorValue, 4>& inputs,
                      gtl::InlinedVector<TensorValue, 4>* output) const;

  Status EvaluateOneFoldable(const NodeDef& node,
                             std::vector<NodeDef>* outputs);

  Status FoldNode(NodeDef* node, GraphDef* output_graph);

  // Predicates and in-place node replacements used by the arithmetic
  // simplifications below.
  bool IsOnes(const NodeDef& node) const;
  bool IsZeros(const NodeDef& node) const;
  void ReplaceOperationWithIdentity(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceOperationWithSnapshot(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
  // Returns the status and reports through `*success` whether the replacement
  // was applied.
  Status ReplaceOperationWithConstant(double value,
                                      const GraphProperties& properties,
                                      const TensorShapeProto& shape,
                                      NodeDef* node, GraphDef* graph,
                                      bool* success);
  void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
  Status FoldGraph(GraphDef* output);

  bool IsSimplifiableReduction(const NodeDef& node,
                               const GraphProperties& properties) const;
  bool IsSimplifiableReshape(const NodeDef& node,
                             const GraphProperties& properties) const;
  Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
                       GraphProperties* properties);
  Status SimplifyNode(bool use_shape_info, NodeDef* node,
                      GraphDef* optimized_graph, GraphProperties* properties);

  Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
                             GraphDef* output);

  // Applies partial constant folding for Concat which is not commutative.
  // Returns true if the transformation applied successfully.
  bool PartialConcatConstFolding(GraphDef* optimized_graph,
                                 GraphProperties* properties, NodeDef* node);

  // Applies partial constant folding for associative operators AddN and
  // AccumulateNV2. Returns true if the transformation applied successfully.
  bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
                                  GraphProperties* properties, NodeDef* node);

  // Applies partial constant propagation through IdentityN operator.
  // Returns true if the transformation applied successfully.
  bool PartialConstPropThroughIdentityN(NodeDef* node);

  // Pushes down constants on '+' and '*' operators if applicable. Returns true
  // if the transformation applied successfully.
  bool ConstantPushDown(NodeDef* node);

  // Aggregate constants present around a conv operator. Returns true if the
  // transformation was applied successfully.
  bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);

  // Strength reduces floating point division by a constant Div(x, const) to
  // multiplication by the reciprocal Mul(x, Reciprocal(const)).
  bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies arithmetic operations with ones or zeros. Returns the status,
  // and updates the success input argument that denotes if any simplification
  // was applied.
  Status SimplifyArithmeticOperations(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
                                      bool* success);

  // Simplifies a Reshape operation to an Identity operation if applicable.
  bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
                       NodeDef* node);

  // Simplifies a Reduction operation to an Identity operation if applicable.
  bool SimplifyReduction(const GraphProperties& properties, NodeDef* node);

  // Switch(x, x) will always feed false to its false branch and true to
  // its true branch. By rewriting the graph a bit, we can propagate these
  // constants down the two output branches, and just use control dependencies
  // to trigger the selected one at runtime. For example,
  //
  //     +------+
  // x-->|Switch|-->a  (in practice there may be multiple consumers of each
  // x-->|      |-->b   output branch.)
  //     +------+
  //
  // Is rewritten as
  //
  //     +------+
  // x-->|Switch|-->Identity--^>Const(false)-->a
  // x-->|      |-->Identity--^>Const(true)-->b
  //     +------+
  bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);

  // Moves constants past Enter node if applicable.
  bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies Pack operation if applicable.
  bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Squeeze operation to an Identity operation if applicable.
  bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Pad operation to an Identity operation if applicable.
  Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
                     GraphDef* optimized_graph, NodeDef* node, bool* success);

  // Simplifies a Tile operation to an Identity operation if applicable.
  Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
                      GraphDef* optimized_graph, NodeDef* node, bool* success);

  // Simplifies a StridedSlice operation to an Identity operation if applicable.
  Status SimplifyStridedSlice(const GraphProperties& properties,
                              bool use_shape_info, GraphDef* optimized_graph,
                              NodeDef* node, bool* success);

  // Simplifies a Slice operation to an Identity operation if applicable.
  Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node, bool* success);

  // Removes Reverse op over dimensions with size 1.
  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node, bool* success);

  // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
  bool RemoveRandomShuffle(const GraphProperties& properties,
                           bool use_shape_info, GraphDef* optimized_graph,
                           NodeDef* node);

  // Removes Shuffle or Transpose op over dimensions of size 1.
  Status RemoveShuffleOrTranspose(const GraphProperties& properties,
                                  bool use_shape_info,
                                  GraphDef* optimized_graph, NodeDef* node,
                                  bool* success);

  // Removes Split or SplitV node if possible.
  bool RemoveSplitOrSplitV(const GraphProperties& properties,
                           GraphDef* optimized_graph, NodeDef* node);

  // NOTE(review): presumably merges a Concat node with a neighbouring Concat
  // and returns true when the rewrite was applied -- confirm against the
  // definition.
  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
                   GraphDef* optimized_graph, NodeDef* node);

  // Optimization level passed to the two-argument constructor.
  RewriterConfig::Toggle opt_level_;
  // Points to an externally provided device or to owned_device_.
  DeviceBase* cpu_device_;
  std::unique_ptr<DeviceBase> owned_device_;

  std::unique_ptr<ResourceMgr> resource_mgr_;
  // The raw pointer and the flags below carry in-class initializers so they
  // hold a defined value no matter which constructor runs; a constructor
  // mem-initializer, if any, still takes precedence.
  GraphDef* graph_ = nullptr;
  std::unique_ptr<NodeMap> node_map_;
  std::unordered_set<string> nodes_to_preserve_;
  std::unordered_set<string> nodes_whitelist_;
  std::unordered_set<string> feed_nodes_;
  bool has_fetch_ = false;
  bool graph_modified_ = false;
  bool graph_contains_assign_or_inplace_op_ = false;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_