aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
blob: 4d213b3f9cb930007096dbdd06b1981e9bab2c32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_

#include <cstddef>
#include <initializer_list>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"

namespace toco {

// Abstract base class for one graph rewrite rule.
//
// Concrete transformations override Run() and Name(). Instances are created
// with `new` and handed to a GraphTransformationsSet, which takes ownership of
// them. Instances are neither copyable nor movable.
class GraphTransformation {
 public:
  // Considers the operator at `op_index` in `model` and may rewrite the graph.
  // NOTE(review): the result presumably reports whether `model` was changed;
  // confirm against the RunGraphTransformations implementation.
  virtual bool Run(Model* model, std::size_t op_index) = 0;
  // Identifier of this transformation. GraphTransformationsSet CHECK-fails if
  // two transformations with the same name are added to one set.
  virtual const char* Name() const = 0;
  virtual ~GraphTransformation() = default;

  // Deleted publicly (rather than privately) so misuse produces a clear
  // "use of deleted function" diagnostic instead of an access error.
  GraphTransformation(const GraphTransformation&) = delete;
  GraphTransformation(GraphTransformation&&) = delete;
  GraphTransformation& operator=(const GraphTransformation&) = delete;
  GraphTransformation& operator=(GraphTransformation&&) = delete;

  // Returns the messages recorded via AddMessageF() since ClearMessages() was
  // last called.
  const std::vector<string>& Messages() const { return messages_; }
  // Discards all recorded messages; should be called after every run of this
  // graph transformation.
  void ClearMessages() { messages_.clear(); }
  // Formats a message with toco::port::StringF and records it. Normally only
  // called by the graph transformation itself during its run (this function
  // could be protected, but stays public for compatibility).
  template <typename... Args>
  void AddMessageF(const char* format, const Args&... args) {
    messages_.push_back(toco::port::StringF(format, args...));
  }

 protected:
  GraphTransformation() {}

  // Messages recorded since the last ClearMessages().
  std::vector<string> messages_;
};

class GraphTransformationsSet {
 public:
  // The choice of a container with fully-specified iteration order
  // ensures that graph transformations are always run in the same order,
  // which avoids having toco randomly fail or produce different results
  // depending on the toolchain. Ideally success/results should be independent
  // of the order in which graph transformations are run, but that's
  // unfortunately not currently guaranteed to be the case.
  using TransformationsContainer =
      std::vector<std::unique_ptr<GraphTransformation>>;

  GraphTransformationsSet() {}
  GraphTransformationsSet(
      const std::initializer_list<GraphTransformation*> transformations) {
    for (GraphTransformation* t : transformations) {
      Add(t);
    }
  }
  void Add(GraphTransformation* transformation) {
    const string& name = transformation->Name();
    CHECK(!names_.count(name));
    names_.insert(name);
    transformations_.emplace_back(transformation);
  }
  TransformationsContainer::const_iterator begin() const {
    return transformations_.begin();
  }
  TransformationsContainer::const_iterator end() const {
    return transformations_.end();
  }
  bool empty() const { return transformations_.empty(); }

 private:
  GraphTransformationsSet(const GraphTransformationsSet& other) = delete;
  GraphTransformationsSet(const GraphTransformationsSet&& other) = delete;
  std::vector<std::unique_ptr<GraphTransformation>> transformations_;
  // Names of transformations in the set. Only used to guard against dupes.
  std::unordered_set<string> names_;
};

// Runs the transformations in `transformations` on `model`.
//
// `message` is only used for logging purposes.
//
// The set is borrowed through a const reference. Ownership of the
// GraphTransformation objects stays with the GraphTransformationsSet, which
// holds them in std::unique_ptr and deletes them when the set itself is
// destroyed, not when this function returns.
void RunGraphTransformations(
    Model* model, const string& message,
    const GraphTransformationsSet& transformations);

// Declares a GraphTransformation subclass named `GTName` with no configuration
// state. Name() returns the stringized class name; Run() is only declared here
// and must be defined out of line. The expansion already ends with `;`.
#define DECLARE_GRAPH_TRANSFORMATION(GTName)               \
  class GTName : public GraphTransformation {              \
   public:                                                 \
    const char* Name() const override { return #GTName; }  \
    bool Run(Model* model, std::size_t op_index) override; \
  };

// Graph transformations that carry no configuration state. Transformations
// with configurable options are declared as explicit classes further below.
// Note: each macro expansion already ends with `;`, so none is added here.
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialPackToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits)
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedMinMax)
DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(ReadArrayMinmaxAndNarrowRangeFromFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveReshapeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSelect)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights)
DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars)
DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)

class PropagateDefaultMinMax : public GraphTransformation {
 public:
  bool Run(Model* model, std::size_t op_index) override;
  const char* Name() const override { return "PropagateDefaultMinMax"; }

  bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
  void DefineTypeRange(ArrayDataType data_type, double min, double max) {
    MinMax minmax;
    minmax.min = min;
    minmax.max = max;
    type_ranges_.emplace_back(data_type, minmax);
  }

 private:
  bool SetArrayMinMax(const string& array_name, Array* array);
  std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
};

// Graph transformation named "RemoveTrivialReshape", with one boolean option.
class RemoveTrivialReshape : public GraphTransformation {
 public:
  const char* Name() const override { return "RemoveTrivialReshape"; }
  bool Run(Model* model, std::size_t op_index) override;

  // Option, false by default.
  // NOTE(review): by its name, when enabled, ExpandDims-like reshapes are also
  // treated as trivial; confirm against Run().
  void set_treat_expand_dims_as_trivial(bool enabled) {
    treat_expand_dims_as_trivial_ = enabled;
  }
  bool treat_expand_dims_as_trivial() const {
    return treat_expand_dims_as_trivial_;
  }

 private:
  bool treat_expand_dims_as_trivial_{false};
};

// Graph transformation named "ResolveConstantFakeQuant", with one boolean
// option.
class ResolveConstantFakeQuant : public GraphTransformation {
 public:
  const char* Name() const override { return "ResolveConstantFakeQuant"; }
  bool Run(Model* model, std::size_t op_index) override;

  // Option, false by default: true if the num_bits should adjust the final
  // data type.
  void set_propagate_fake_quant_num_bits(bool enabled) {
    propagate_fake_quant_num_bits_ = enabled;
  }
  bool propagate_fake_quant_num_bits() const {
    return propagate_fake_quant_num_bits_;
  }

 private:
  bool propagate_fake_quant_num_bits_{false};
};

// Graph transformation named "EnsureUint8WeightsSafeForFastInt8Kernels", with
// two boolean options that both default to false.
class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
 public:
  const char* Name() const override {
    return "EnsureUint8WeightsSafeForFastInt8Kernels";
  }
  bool Run(Model* model, std::size_t op_index) override;

  // Option: whether weight values may be nudged.
  // NOTE(review): meaning inferred from the name; confirm against Run().
  void set_allow_nudging_weights(bool allowed) {
    allow_nudging_weights_ = allowed;
  }
  bool allow_nudging_weights() const { return allow_nudging_weights_; }

  // Option: whether a default-ranges flag was supplied.
  // NOTE(review): meaning inferred from the name; confirm how Run() uses it.
  void set_has_default_ranges_flag(bool present) {
    has_default_ranges_flag_ = present;
  }
  bool has_default_ranges_flag() const { return has_default_ranges_flag_; }

 private:
  bool allow_nudging_weights_{false};
  bool has_default_ranges_flag_{false};
};

// Graph transformation named "IdentifyDilatedConv", with one boolean option.
class IdentifyDilatedConv : public GraphTransformation {
 public:
  const char* Name() const override { return "IdentifyDilatedConv"; }
  bool Run(Model* model, std::size_t op_index) override;

  // Option, true by default.
  // NOTE(review): by its name, controls whether depthwise convolutions are
  // also considered; confirm against Run().
  void set_identify_depthwise_conv(bool enabled) {
    identify_depthwise_conv_ = enabled;
  }
  bool identify_depthwise_conv() const { return identify_depthwise_conv_; }

 private:
  bool identify_depthwise_conv_{true};
};

// The declaration helper macro is only needed inside this header; undefine it
// so it is not visible to files that include this header.
#undef DECLARE_GRAPH_TRANSFORMATION

}  // end namespace toco

#endif  // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_