aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
blob: d0113237ce6e43140704672c4cfa5866b7cf49a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Sets |array|'s final_data_type to |new_data_type| and, where applicable,
// brings its min/max (and quantization params) in line with the new type.
//
// Cases handled:
//  - |array| has both minmax and quantization_params (already quantized):
//    the range representable by the old quantized data_type is re-expressed
//    as a new min/max, quantization params are re-chosen for
//    |new_data_type|, and data_type itself is switched immediately. Both the
//    current data_type and |new_data_type| must be quantized types; otherwise
//    a CHECK fails.
//  - |array| has neither quantization_params nor minmax: if |new_minmax| is
//    non-null it is copied in as the array's min/max.
//  - Any other combination (minmax without quantization_params, or
//    quantization_params without minmax): only final_data_type is updated.
//
// Progress is reported via |transformation|->AddMessageF.
// Returns true if final_data_type changed, if the array went through the
// already-quantized path (which always reports a change), or if |new_minmax|
// was applied.
bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
                         ArrayDataType new_data_type,
                         const MinMax* new_minmax) {
  // Always record the requested final type, whether or not the array has
  // been quantized yet.
  bool changed = false;
  if (array->final_data_type != new_data_type) {
    array->final_data_type = new_data_type;
    changed = true;
  }

  if (array->minmax && array->quantization_params) {
    // The array is already quantized and has min/max info.
    // As we are changing the data type we need to fix up the existing min/max
    // to the new data type range.

    double old_quantized_min, old_quantized_max;
    CHECK(GetQuantizedDataTypeNumericalRange(
        array->data_type, &old_quantized_min, &old_quantized_max))
        << "Existing data type is not quantized: "
        << ArrayDataTypeName(array->data_type);
    // Note: only new_quantized_max is used below; new_quantized_min is
    // computed solely as part of the range query / CHECK.
    double new_quantized_min, new_quantized_max;
    CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min,
                                             &new_quantized_max))
        << "New data type is not quantized: "
        << ArrayDataTypeName(new_data_type);

    // Compute new minmax values.
    // min: old_quantized_min dequantized with the existing zero_point/scale.
    double min = (old_quantized_min - array->quantization_params->zero_point) *
                 array->quantization_params->scale;
    // max: the dequantized value one quantized step above old_quantized_max,
    // then reduced by 1 / (new_quantized_max + 1).
    // NOTE(review): the rationale for this particular adjustment is not
    // evident from this file; confirm before changing it.
    double max =
        (old_quantized_max + 1 - array->quantization_params->zero_point) *
        array->quantization_params->scale;
    max = max - 1.0 / (new_quantized_max + 1);

    // array->minmax is known to be non-null here, so this returns the
    // existing min/max (logged below as the "from" values).
    auto& array_minmax = array->GetOrCreateMinMax();
    transformation->AddMessageF(
        "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
        array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
        ArrayDataTypeName(new_data_type));
    array_minmax.min = min;
    array_minmax.max = max;
    // Re-choose quantization_params (written in place) for the new type; the
    // min/max must already be updated above since *array is passed in.
    ChooseQuantizationParamsForArrayAndQuantizedDataType(
        *array, new_data_type, array->quantization_params.get());
    // Directly change the type as the array was already quantized.
    array->data_type = new_data_type;
    changed = true;
  } else if (!array->quantization_params) {
    // Array has not yet been quantized so we can just set the final data type
    // and assign the new min/max value (if provided).

    // An existing min/max is never overwritten on this path.
    if (!array->minmax && new_minmax) {
      transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
                                  new_minmax->min, new_minmax->max,
                                  ArrayDataTypeName(new_data_type));
      auto& array_minmax = array->GetOrCreateMinMax();
      array_minmax.min = new_minmax->min;
      array_minmax.max = new_minmax->max;
      changed = true;
    }
  }

  return changed;
}

// Returns true if the op blocks our backward recursive data type propagation.
// Decides whether backward (producer-ward) data type propagation must stop at
// |op|. Only an allow-list of op types is traversed:
//  - concatenations (Concatenation/Concat/ConcatV2), whose inputs are all
//    expected to share the same range;
//  - Dequantize, which is inserted between the value of interest and its
//    FakeQuant;
//  - Gather, whose params need to be changed to the new data type;
//  - Reshape, Transpose, Select and Tile, which forward existing values
//    without computing new ones.
// Every other op type blocks propagation.
bool DoesOpBlockBackwardPropagation(const Operator& op) {
  bool traversable = false;
  switch (op.type) {
    case OperatorType::kConcatenation:
    case OperatorType::kConcat:
    case OperatorType::kConcatV2:
    case OperatorType::kDequantize:
    case OperatorType::kGather:
    case OperatorType::kReshape:
    case OperatorType::kTranspose:
    case OperatorType::kSelect:
    case OperatorType::kTile:
      traversable = true;
      break;
    default:
      break;
  }
  return !traversable;
}

// Returns true if the input of an op blocks our backward recursive data type
// propagation.
// Decides whether backward propagation must skip input |input_index| of |op|
// even when |op| itself can be crossed. This keeps the new data type out of
// auxiliary arguments:
//  - Select: input 0 is skipped (presumably its condition tensor).
//  - Gather, Reshape, Transpose, Tile: every input except input 0 is skipped
//    (gather indices, shape/dimension arguments, tile multiples).
// All inputs of any other op type are allowed through.
bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
  const bool is_first_input = (input_index == 0);
  switch (op.type) {
    case OperatorType::kSelect:
      return is_first_input;
    case OperatorType::kGather:
    case OperatorType::kReshape:
    case OperatorType::kTranspose:
    case OperatorType::kTile:
      return !is_first_input;
    default:
      return false;
  }
}

// Propagates the data type up into the input arrays if they are model inputs
// that may need their type changed. May act recursively if the inputs are
// produced by ops that we can move over (such as Dequantize).
// Propagates |new_data_type| backward from |op| into its input arrays.
//
// For each input of |op|:
//  - inputs whose final_data_type already equals |new_data_type| are skipped;
//  - inputs rejected by DoesOpInputBlockBackwardPropagation (e.g. gather
//    indices, reshape shapes, tile multiples) are skipped;
//  - otherwise the array is updated via ChangeArrayDataType (which may also
//    apply |new_minmax| to an array that has neither min/max nor quantization
//    params), and the walk recurses into every op producing that array,
//    unless DoesOpBlockBackwardPropagation says that producer blocks.
//
// An array is only recursed through after its final_data_type has been set
// to |new_data_type|, so within one propagation each array is changed and
// walked through at most once, which bounds the recursion.
//
// Returns true if any array was modified.
bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
                                          Model* model, Operator* op,
                                          ArrayDataType new_data_type,
                                          const MinMax& new_minmax) {
  bool did_change = false;
  // Unsigned index avoids a signed/unsigned comparison with inputs.size().
  for (std::size_t input_index = 0; input_index < op->inputs.size();
       ++input_index) {
    const auto& input = op->inputs[input_index];
    auto& input_array = model->GetArray(input);
    if (input_array.final_data_type == new_data_type) {
      // Final data type is already the requested one - skip.
      continue;
    }

    // Prevent moving into constant param args that we don't want to modify.
    if (DoesOpInputBlockBackwardPropagation(*op,
                                            static_cast<int>(input_index))) {
      continue;
    }

    // From here on final_data_type is known to differ from new_data_type.
    transformation->AddMessageF(
        "Adjusting input final data type of array %s from %s to %s", input,
        ArrayDataTypeName(input_array.final_data_type),
        ArrayDataTypeName(new_data_type));
    did_change |= ChangeArrayDataType(transformation, &input_array,
                                      new_data_type, &new_minmax);

    // Walk up into all ops producing the inputs to this op.
    for (auto& producing_op : model->operators) {
      if (DoesOpBlockBackwardPropagation(*producing_op)) {
        continue;
      }
      for (const auto& output : producing_op->outputs) {
        if (input == output) {
          did_change |= RecursivelyBackwardPropagateDataType(
              transformation, model, producing_op.get(), new_data_type,
              new_minmax);
        }
      }
    }
  }
  return did_change;
}

// Returns true if the op blocks our forward recursive data type propagation.
// Decides whether forward (consumer-ward) data type propagation must stop at
// |op|. Only another FakeQuant stops it, because that op will likely carry
// its own, possibly different, quantization parameters. All other op types
// are crossed.
bool DoesOpBlockForwardPropagation(const Operator& op) {
  const bool is_fake_quant = (op.type == OperatorType::kFakeQuant);
  return is_fake_quant;
}

// Recurses down the graph setting the data type of all arrays until an operator
// that blocks propagation (like another FakeQuant) or a final_data_type is
// already specified.
// Propagates |new_data_type| forward from |op| into its output arrays.
//
// Each output whose final_data_type differs from |new_data_type| is updated
// via ChangeArrayDataType (no min/max is forced on this path), and the walk
// recurses into every op that consumes that output, stopping at ops for which
// DoesOpBlockForwardPropagation returns true (another FakeQuant). Outputs
// already of |new_data_type| are skipped, which bounds the recursion.
//
// NOTE(review): the outputs of every non-FakeQuant consumer are retyped,
// including ops whose outputs may not carry the propagated values; confirm
// this is intended for all consumer op types.
//
// Returns true if any array was modified.
bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
                                         Model* model, Operator* op,
                                         ArrayDataType new_data_type) {
  bool did_change = false;
  for (const auto& output : op->outputs) {
    auto& output_array = model->GetArray(output);
    if (output_array.final_data_type == new_data_type) {
      // Final data type is already the requested one - skip.
      continue;
    }

    // From here on final_data_type is known to differ from new_data_type
    // (this includes the kNone case).
    transformation->AddMessageF(
        "Adjusting output final data type of array %s from %s to %s", output,
        ArrayDataTypeName(output_array.final_data_type),
        ArrayDataTypeName(new_data_type));
    did_change |= ChangeArrayDataType(transformation, &output_array,
                                      new_data_type, nullptr);

    // Walk down into all ops consuming the output of this op.
    for (auto& consuming_op : model->operators) {
      if (DoesOpBlockForwardPropagation(*consuming_op)) {
        continue;
      }
      for (const auto& input : consuming_op->inputs) {
        if (input == output) {
          did_change |= RecursivelyForwardPropagateDataType(
              transformation, model, consuming_op.get(), new_data_type);
        }
      }
    }
  }
  return did_change;
}

}  // namespace

// Propagates the num_bits on a FakeQuant operator into the final data types
// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know
// the output must be int16 and assume all inputs up until the preceding op are
// also 16.
//
// This can be thought of as a bidirectional flood-fill of the num_bits implied
// final_data_type that terminates at other FakeQuant ops (and a few others as
// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation).
// Once all FakeQuant ops have been visited the arrays should all have
// appropriate final_data_types if the source graph was annotated with the
// proper FakeQuant ops.
//
// Annotating a graph requires following a few hard rules:
// - every input MUST have a FakeQuant immediately following it
// - every output MUST have a FakeQuant immediately preceding it
// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant
//   immediately following it
// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc)
//   MUST have FakeQuants between them and the consuming op
// Additional FakeQuants may be used if desired, especially in areas that may
// suffer from large precision changes - such as between a Softmax and a
// FullyConnected. Only by validating accuracy differences between float
// inference with the FakeQuant ops simulating quantization and the actually
// quantized graph can you be sure the appropriate FakeQuant ops are present.
//
// You can tell if you're missing some FakeQuants by looking for warnings from
// quantize.cc about minmax ranges being determined by the contents of constant
// arrays. This will almost never produce functional models during inference.
//
// As this transformation may change the data types and ranges of input and
// output arrays, downstream tools must also be sure to parse the output model
// flags to get the post-Transform values that may have changed due to this
// transformation.
//
// A FakeQuant whose min/max has not been resolved yet (minmax is null) is
// skipped without modification; it is picked up on a later pass once its
// min/max is available.
//
// This isn't a GraphTransformation in the traditional sense, as it affects ops
// outside of the one under transformation. This is primarily so that we can
// utilize the graph traversal and repeated pass system underlying the
// transformation system to exhaustively find all FakeQuant ops. It also gets us
// nice logging and integration with the graphviz video dumping mode.
// In general you should not copy this style of transformation and stick to
// local-only changes as seen in the other transformations.
::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model,
                                                    std::size_t op_index,
                                                    bool* modified) {
  *modified = false;
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  if (op->type != OperatorType::kFakeQuant) {
    return ::tensorflow::Status::OK();
  }
  auto* fakequant_op = static_cast<FakeQuantOperator*>(op);

  ArrayDataType quantized_data_type = ArrayDataType::kNone;
  if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
                                           &quantized_data_type)) {
    AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
                LogName(*op), fakequant_op->num_bits);
    return ::tensorflow::Status::OK();
  }

  // The min/max is dereferenced below; without this guard an unresolved
  // FakeQuant (presumably one whose min/max still comes from input arrays)
  // would crash here. Skip it for now; the repeated-pass driver will revisit
  // this op on a later pass.
  if (!fakequant_op->minmax) {
    AddMessageF("FakeQuant op %s has no min/max yet, skipping", LogName(*op));
    return ::tensorflow::Status::OK();
  }
  const auto& final_minmax = *fakequant_op->minmax;

  AddMessageF(
      "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s",
      LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max,
      ArrayDataTypeName(quantized_data_type));

  bool did_change = false;

  // Propagate the FakeQuant information backward up the graph.
  // This will possibly adjust input arrays or constant types (like Gather).
  did_change |= RecursivelyBackwardPropagateDataType(
      this, model, op, quantized_data_type, final_minmax);

  // Propagate the FakeQuant information forward down the graph.
  // This will possibly adjust output arrays.
  did_change |=
      RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);

  *modified = did_change;
  return ::tensorflow::Status::OK();
}

}  // namespace toco