aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize/python/quantize_graph.py
blob: 11d052d7f491dc029d1bda9b47364d6e9c880a67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API to simulate quantization on a python graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops


def _create_graph(input_graph=None,
                  is_training=True,
                  weight_bits=8,
                  activation_bits=8,
                  quant_delay=None,
                  freeze_bn_delay=None,
                  scope=None):
  """Rewrites a graph in place so that it simulates quantization.

  Fake quantization ops are inserted to simulate the error introduced by
  quantization. Because the graph is modified in place, references to nodes
  and tensors obtained before this call may not behave as expected afterwards.

  With the target graph installed as the default graph, two rewrites are
  applied in this order:
    1. `fold_batch_norms.FoldBatchNorms`, which receives `is_training` and
       `freeze_bn_delay`.
    2. `quantize.Quantize`, which receives `is_training`, `quant_delay`,
       `weight_bits`, `activation_bits` and `scope`.

  Args:
    input_graph: The tf.Graph to rewrite. If None, the current default graph
      (`ops.get_default_graph()`) is used.
    is_training: True when rewriting a training graph, False for an eval graph.
    weight_bits: Number of bits used to quantize weights.
    activation_bits: Number of bits used to quantize activations.
    quant_delay: Number of steps after which weights and activations are
      quantized during training. Forwarded unchanged to `quantize.Quantize`.
    freeze_bn_delay: Number of steps after which the moving mean and variance
      are frozen and used instead of batch statistics during training. It is
      expected to be greater than quant_delay and to correspond to the step at
      which training has almost converged. This function does not check
      that relationship.
    scope: If not None, only ops inside this scope are quantized. NOTE: only
      the `quantize.Quantize` pass receives `scope`; the batch norm folding
      pass is called without it.

  Raises:
    This function does no validation of its own. Any exception raised by
    `fold_batch_norms.FoldBatchNorms` or `quantize.Quantize` propagates to the
    caller.
  """
  graph = ops.get_default_graph() if input_graph is None else input_graph

  # Options consumed by the quantization pass only.
  quantize_options = dict(
      quant_delay=quant_delay,
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      scope=scope)

  with graph.as_default():
    # Batch norm folding first, then fake-quant insertion (order is significant
    # to the original implementation and is preserved here).
    fold_batch_norms.FoldBatchNorms(
        graph, is_training=is_training, freeze_batch_norm_delay=freeze_bn_delay)
    quantize.Quantize(graph, is_training, **quantize_options)


def create_training_graph(input_graph=None, quant_delay=0):
  """Rewrites a training graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  Fake quantization ops are inserted to simulate the error introduced by
  quantization. Because the graph is modified in place, references to nodes
  and tensors obtained before this call may not behave as expected afterwards.

  The default quant_delay is suitable for finetuning an already trained
  floating point model (recommended). To train a quantized model from scratch,
  set quant_delay to the number of steps the floating point model takes to
  converge. Quantization is then activated at that point and effectively
  finetunes the model. Training from scratch without a quant_delay can often
  fail.

  Weights and activations use the default bit widths of the underlying
  rewrite (8 bits), and no freeze_bn_delay is set. Use
  experimental_create_training_graph to control those options.

  Args:
    input_graph: The tf.Graph to be transformed. If None, the default graph is
      used.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.

  Raises:
    No validation is done here. Exceptions raised by the underlying graph
    rewrite propagate to the caller.
  """
  # TODO(raghuramank) Need to have freeze_bn_delay be a function of batch size.
  # Until then this entry point always passes freeze_bn_delay=None (presumably
  # meaning batch norm statistics are never frozen -- confirm against
  # FoldBatchNorms). Use the experimental API if you need to tune it.
  _create_graph(
      input_graph=input_graph,
      quant_delay=quant_delay,
      is_training=True,
      freeze_bn_delay=None)


def create_eval_graph(input_graph=None):
  """Rewrites an eval graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  Fake quantization ops are inserted to simulate the error introduced by
  quantization. Because the graph is modified in place, references to nodes
  and tensors obtained before this call may not behave as expected afterwards.

  The underlying rewrite is run with its default options: 8-bit weights and
  activations, no quant_delay, no freeze_bn_delay and no scope restriction.

  Args:
    input_graph: The tf.Graph to be transformed. If None, the default graph is
      used.

  Raises:
    No validation is done here. Exceptions raised by the underlying graph
    rewrite propagate to the caller.
  """
  # Everything except the graph and the eval-mode flag keeps its default.
  _create_graph(input_graph, is_training=False)


def experimental_create_training_graph(input_graph=None,
                                       weight_bits=8,
                                       activation_bits=8,
                                       quant_delay=0,
                                       freeze_bn_delay=None,
                                       scope=None):
  """Rewrites a training graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function exposes experimental options that create_training_graph does
  not (yet) offer. The resulting behavior may be undefined.

  Fake quantization ops are inserted to simulate the error introduced by
  quantization. Because the graph is modified in place, references to nodes
  and tensors obtained before this call may not behave as expected afterwards.

  The default quant_delay is suitable for finetuning an already trained
  floating point model (recommended). To train a quantized model from scratch,
  set quant_delay to the number of steps the floating point model takes to
  converge. Quantization is then activated at that point and effectively
  finetunes the model. Training from scratch without a quant_delay can often
  fail.

  Args:
    input_graph: The tf.Graph to be transformed. If None, the default graph is
      used.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.
    freeze_bn_delay: Number of steps after which the moving mean and variance
      are frozen and used instead of batch statistics during training. It
      should be greater than quant_delay and should correspond to the step at
      which training has almost converged. This is not checked here.
    scope: The scope to be transformed. If not None, only the ops in this scope
      are quantized.

  Raises:
    No validation is done here. Exceptions raised by the underlying graph
    rewrite propagate to the caller.
  """
  # Forward every caller-supplied option unchanged to the shared rewrite.
  options = dict(
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      quant_delay=quant_delay,
      freeze_bn_delay=freeze_bn_delay,
      scope=scope)
  _create_graph(input_graph, is_training=True, **options)


def experimental_create_eval_graph(input_graph=None,
                                   weight_bits=8,
                                   activation_bits=8,
                                   scope=None):
  """Rewrites an eval graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function exposes experimental options that create_eval_graph does not
  (yet) offer. The resulting behavior may be undefined.

  Fake quantization ops are inserted to simulate the error introduced by
  quantization. Because the graph is modified in place, references to nodes
  and tensors obtained before this call may not behave as expected afterwards.

  Args:
    input_graph: The tf.Graph to be transformed. If None, the default graph is
      used.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    scope: The scope to be transformed. If not None, only the ops in this scope
      are quantized.

  Raises:
    No validation is done here. Exceptions raised by the underlying graph
    rewrite propagate to the caller.
  """
  # quant_delay and freeze_bn_delay are left at the shared rewrite's defaults.
  options = dict(
      weight_bits=weight_bits, activation_bits=activation_bits, scope=scope)
  _create_graph(input_graph, is_training=False, **options)