aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
blob: cf093c6f17b45839156dae0d06ca2fc7e5e2f3c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_

#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"

namespace toco {

// Quantization helpers shared by toco graph transformations. Only the
// declarations live here; the definitions are in the corresponding .cc file.

// Gets the target quantized data type of an array based on the fake quant op.
// For example, if the num_bits is 8 the data type will be kUint8.
//
// `op` is read-only; the inferred type is written to `*out_quantized_data_type`.
// NOTE(review): the meaning of the bool result is not documented here —
// presumably true when a quantized type was inferred and false when the op's
// settings map to no supported type; confirm against the implementation
// before relying on `*out_quantized_data_type` after a false return.
bool InferQuantizedDataTypeFromFakeQuant(
    const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type);

// Gets the min/max numerical range for the given quantized data type.
// For example, kUint8 will return [0,255].
// Returns true if the ranges were set and false if the type is not quantized.
//
// The bounds are reported through `*out_min_value` / `*out_max_value` as
// doubles. On a false return the outputs were not set by this call, so callers
// must not read them unless they initialized them beforehand.
bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
                                        double* out_min_value,
                                        double* out_max_value);

// Returns the quantized data type of an array, falling back to the provided
// default data type.
//
// NOTE(review): the exact condition under which `default_type` is returned
// (e.g. the array carries no quantized type of its own) is decided by the
// implementation; confirm there.
ArrayDataType GetQuantizedDataType(const Array& array,
                                   ArrayDataType default_type);

// Chooses the quantization params for a given array and a given target
// quantized data type (which may not be the array's current data type).
//
// `array` is taken by const reference and is not modified through this call;
// the result is written to `*quantization_params`.
// NOTE(review): presumably the params are derived from the array's recorded
// min/max range — confirm what happens when the array has none.
void ChooseQuantizationParamsForArrayAndQuantizedDataType(
    const Array& array, ArrayDataType quantized_data_type,
    QuantizationParams* quantization_params);

// Quantizes an array by setting its data type and (if constant) quantizing
// all values in the array.
//
// `name` identifies the array to quantize within `model`, which is mutated in
// place. `quantization_params` are the params to quantize with (see
// ChooseQuantizationParamsForArrayAndQuantizedDataType).
// NOTE(review): `transformation` is presumably used only for reporting
// messages on behalf of the calling graph transformation — confirm.
void QuantizeArray(GraphTransformation* transformation, Model* model,
                   const string& name, ArrayDataType quantized_data_type,
                   const QuantizationParams& quantization_params);

// Returns true if the given array, when quantized, contains only values between
// the provided clamp min/max.
// Either clamp_min or clamp_max may be +/-infinity to indicate that the value
// is unbounded on that side.
//
// `array` is taken by const reference and is not modified through this call.
// NOTE(review): whether the clamp bounds are inclusive, and what is returned
// for an array lacking quantization info, is defined by the implementation —
// confirm. As with QuantizeArray, `transformation` is presumably used only for
// reporting messages.
bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
                                 const Array& array, double clamp_min,
                                 double clamp_max);

}  // namespace toco

#endif  // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_