aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/gradient_checker.py
blob: 12afcd0b517d5e85112c067ccaca5693e5a4e231 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Gradient checker for any ops, graphs.

The gradient checker verifies numerically that an op/graph properly
computes the gradients.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _product(t):
  if isinstance(t, int):
    return t
  else:
    y = 1
    for x in t:
      y *= x
    return y


def _extra_feeds(extra_feed_dict, new_feeds):
  if not extra_feed_dict:
    return new_feeds
  r = {}
  r.update(extra_feed_dict)
  r.update(new_feeds)
  return r


def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx,
                                  extra_feed_dict):
  """Computes the theoretical Jacobian for dy/dx.

  Computes the theoretical Jacobian using the ops generated by
  compute_gradient(). `dx` is evaluated once per real component of `dy`, each
  time feeding a `dy` that is 1 in that component and 0 everywhere else; the
  fetched backprop value becomes one column of the Jacobian.

  The ops are run in the session returned by `ops.get_default_session()`.

  Args:
    x: the tensor "x".
    x_shape: the dimensions of x as a tuple or an array of ints.
    x_data: a numpy array as the input data for x
    dy: the tensor "dy".
    dy_shape: the dimensions of dy as a tuple or an array of ints.
    dx: Tensor or IndexedSlices representing dx
    extra_feed_dict: dict that allows fixing specified tensor values
      during the jacobian calculation.

  Returns:
    A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
    and "dy_size" columns where "x_size" is the number of elements in x and
    "dy_size" is the number of elements in dy. A complex x or dy contributes
    two real entries per element, so the corresponding size is doubled.

  Raises:
    ValueError: If `dy` is empty but the gradient is nonzero or does not have
      the shape of `x_data`.
  """
  # Complex vectors are treated as vectors of twice as many reals.
  if x.dtype.is_complex:
    x_shape = tuple(x_shape) + (2,)
  dy_factor = 2 if dy.dtype.is_complex else 1

  # To compute the jacobian, we treat x and y as one-dimensional vectors.
  x_size = _product(x_shape)
  x_val_size = _product(x_shape[1:])  # This is used for sparse gradients
  dy_size = _product(dy_shape) * dy_factor

  # Allocate 2-D Jacobian, with x dimensions smashed into the first
  # dimension and y dimensions smashed into the second.
  jacobian = np.zeros((x_size, dy_size),
                      dtype=x.dtype.real_dtype.as_numpy_dtype)

  # For each entry of dy, we set it to 1 and everything else to 0 and
  # compute the backprop -- this will give us one column of the Jacobian
  # matrix.
  dy_data = np.zeros(dy_shape, dtype=dy.dtype.as_numpy_dtype)
  # Flat, real-typed view of dy_data: writing through it changes the dy_data
  # array that is fed below, one real component at a time.
  dy_data_flat = dy_data.ravel().view(dy.dtype.real_dtype.as_numpy_dtype)
  sess = ops.get_default_session()
  for col in range(dy_size):
    dy_data_flat[col] = 1
    if isinstance(dx, ops.IndexedSlices):
      backprop_indices, backprop_values = sess.run(
          [dx.indices, dx.values],
          feed_dict=_extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data}))
      # Each index selects a block of x_val_size consecutive Jacobian rows.
      # Values are accumulated with += so repeated indices add up.
      for i, v in zip(backprop_indices, backprop_values):
        r_begin = i * x_val_size
        r_end = r_begin + x_val_size
        jacobian[r_begin:r_end, col] += v.flat
    else:
      assert isinstance(dx, ops.Tensor), "dx = " + str(dx)
      backprop = sess.run(
          dx, feed_dict=_extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data}))
      # Reinterpret the (possibly complex) gradient as a flat vector of reals.
      jacobian[:, col] = backprop.ravel().view(jacobian.dtype)
    # Restore the all-zeros dy before moving on to the next column.
    dy_data_flat[col] = 0

  # If the output is empty, run the gradients at least once and make sure
  # they produce zeros.
  # NOTE(review): this branch fetches `dx` directly and reads `.shape` on the
  # result, so it looks like it assumes a dense `dx` -- confirm the behavior
  # when `dx` is an IndexedSlices.
  if not dy_size:
    backprop = sess.run(
        dx, feed_dict=_extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data}))
    if backprop.shape != x_data.shape:
      raise ValueError("Empty gradient has wrong shape: expected %s, got %s" %
                       (x_data.shape, backprop.shape))
    if np.any(backprop):
      raise ValueError("Empty tensor with nonzero gradients")

  logging.vlog(1, "Theoretical Jacobian =\n%s", jacobian)
  return jacobian


def _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta,
                              extra_feed_dict):
  """Estimates the Jacobian for dy/dx with central finite differences.

  Every real component of x is shifted by +delta and by -delta in turn; the
  difference of the two resulting values of y, divided by 2 * delta, forms one
  row of the Jacobian.

  Args:
    x: the tensor "x".
    x_shape: the dimensions of x as a tuple or an array of ints.
    x_data: a numpy array as the input data for x
    y: the tensor "y".
    y_shape: the dimensions of y as a tuple or an array of ints.
    delta: the amount of perturbation we give to the input
    extra_feed_dict: dict that allows fixing specified tensor values
      during the jacobian calculation.

  Returns:
    A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
    and "y_size" columns where "x_size" is the number of elements in x and
    "y_size" is the number of elements in y. A complex x or y counts twice,
    once for each real component.
  """
  # bfloat16 is too coarse to represent a perturbation such as delta, and the
  # numeric Jacobian serves as the ground truth, so bfloat16 tensors and data
  # are widened to float32 before anything is evaluated.
  if x.dtype == dtypes.bfloat16:
    x = math_ops.cast(x, dtypes.float32)
  if y.dtype == dtypes.bfloat16:
    y = math_ops.cast(y, dtypes.float32)
  if x_data.dtype == dtypes.bfloat16.as_numpy_dtype:
    x_data = x_data.astype(np.float32)

  # x and y are handled as flat vectors of reals.
  num_rows = _product(x_shape) * (2 if x.dtype.is_complex else 1)
  num_cols = _product(y_shape) * (2 if y.dtype.is_complex else 1)
  x_real = x.dtype.real_dtype.as_numpy_dtype
  y_real = y.dtype.real_dtype.as_numpy_dtype

  # Coerce the input data to x's dtype and express the divisor 2 * delta as a
  # scalar of y's real dtype.
  x_data = np.asarray(x_data, dtype=x.dtype.as_numpy_dtype)
  scale = np.asarray(2 * delta, dtype=y_real)[()]

  def _y_at(x_value):
    """Evaluates y with `x_value` fed for x on top of the extra feeds."""
    return y.eval(feed_dict=_extra_feeds(extra_feed_dict, {x: x_value}))

  jacobian = np.zeros((num_rows, num_cols), dtype=x_real)
  for row in range(num_rows):
    # Shift a single real component upwards and evaluate y ...
    x_plus = x_data.copy()
    x_plus.ravel().view(x_real)[row] += delta
    y_plus = _y_at(x_plus)
    # ... then shift the same component downwards and evaluate again.
    x_minus = x_data.copy()
    x_minus.ravel().view(x_real)[row] -= delta
    y_minus = _y_at(x_minus)
    jacobian[row, :] = ((y_plus - y_minus) / scale).ravel().view(y_real)

  logging.vlog(1, "Numeric Jacobian =\n%s", jacobian)
  return jacobian


def _compute_dx_and_dy(x, y, y_shape):
  """Builds the gradient of y wrt x together with a feedable upstream dy.

  Args:
    x: the tensor to differentiate with respect to.
    y: the tensor being differentiated.
    y_shape: the dimensions of y, used as the shape of the made-up dy.

  Returns:
    A pair `(dx, dy)`: the single gradient produced for `x`, and the constant
    tensor that stands in for the incoming gradient of `y`.
  """
  # The value of the made-up dy is irrelevant because it is always fed. It is
  # routed through an identity op so that feeding stays possible even when dx
  # would otherwise be the very same tensor as dy (as for the Add operation):
  # a tensor that is being fed cannot also be fetched.
  with x.graph.as_default():
    fed_dy = constant_op.constant(1.0, shape=y_shape, dtype=y.dtype)
    upstream_grad = array_ops.identity(fed_dy)
  # Gradient of y wrt. x, seeded with the identity of the made-up dy.
  grad_list = gradients.gradients(y, x, upstream_grad)
  assert len(grad_list) == 1
  dx = grad_list[0]
  return dx, fed_dy


def _compute_gradient(x,
                      x_shape,
                      dx,
                      y,
                      y_shape,
                      dy,
                      x_init_value=None,
                      delta=1e-3,
                      extra_feed_dict=None):
  """Computes the theoretical and numerical jacobian.

  Args:
    x: the tensor "x".
    x_shape: the dimensions of x as a tuple or an array of ints.
    dx: the gradient of y with respect to x.
    y: the tensor "y".
    y_shape: the dimensions of y as a tuple or an array of ints.
    dy: the tensor that is fed as the incoming gradient of y.
    x_init_value: (optional) numpy array used as the value of x. When None,
      uniformly random data of shape `x_shape` is generated.
    delta: (optional) the amount of perturbation for the numeric Jacobian.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the Jacobian calculation.

  Returns:
    A pair `(theoretical_jacobian, numeric_jacobian)` of 2-d numpy arrays.
  """
  supported_types = [dtypes.float16, dtypes.bfloat16, dtypes.float32,
                     dtypes.float64, dtypes.complex64, dtypes.complex128]
  x_type = dtypes.as_dtype(x.dtype)
  assert x_type.base_dtype in supported_types, (
      "Don't support type %s for x" % x_type.name)
  y_type = dtypes.as_dtype(y.dtype)
  assert y_type.base_dtype in supported_types, (
      "Don't support type %s for y" % y_type.name)

  if x_init_value is None:
    # No data supplied: draw random values, and for complex x draw a random
    # imaginary part as well.
    x_data = np.random.random_sample(x_shape).astype(x_type.as_numpy_dtype)
    if x_type.is_complex:
      x_data.imag = np.random.random_sample(x_shape)
  else:
    init_shape = list(x_init_value.shape)
    assert list(x_shape) == init_shape, (
        "x_shape = %s, init_data shape = %s" % (x_shape, init_shape))
    x_data = x_init_value

  theoretical = _compute_theoretical_jacobian(
      x, x_shape, x_data, dy, y_shape, dx, extra_feed_dict=extra_feed_dict)
  numeric = _compute_numeric_jacobian(
      x, x_shape, x_data, y, y_shape, delta, extra_feed_dict=extra_feed_dict)
  return theoretical, numeric


def _compute_gradient_list(x,
                           x_shape,
                           y,
                           y_shape,
                           x_init_value=None,
                           delta=1e-3,
                           init_targets=None,
                           extra_feed_dict=None):
  """Compute gradients for a list of x values."""
  assert isinstance(x, list)
  dx, dy = zip(*[_compute_dx_and_dy(xi, y, y_shape) for xi in x])

  if init_targets is not None:
    assert isinstance(init_targets, (list, tuple))
    for init in init_targets:
      init.run()
  if x_init_value is None:
    x_init_value = [None] * len(x)
  ret = [_compute_gradient(xi, x_shapei, dxi, y, y_shape, dyi, x_init_valuei,
                           delta, extra_feed_dict=extra_feed_dict)
         for xi, x_shapei, dxi, dyi, x_init_valuei in zip(x, x_shape, dx, dy,
                                                          x_init_value)]
  return ret


@tf_export("test.compute_gradient")
def compute_gradient(x,
                     x_shape,
                     y,
                     y_shape,
                     x_init_value=None,
                     delta=1e-3,
                     init_targets=None,
                     extra_feed_dict=None):
  """Computes and returns the theoretical and numerical Jacobian.

  If `x` or `y` is complex, the Jacobian will still be real but the
  corresponding Jacobian dimension(s) will be twice as large.  This is required
  even if both input and output is complex since TensorFlow graphs are not
  necessarily holomorphic, and may have gradients not expressible as complex
  numbers.  For example, if `x` is complex with shape `[m]` and `y` is complex
  with shape `[n]`, each Jacobian `J` will have shape `[m * 2, n * 2]`.  The
  complex values are reinterpreted as flat arrays of reals, so rows `2 * i`
  and `2 * i + 1` belong to the real and imaginary part of `x[i]`, and columns
  `2 * j` and `2 * j + 1` to the real and imaginary part of `y[j]`.

  Args:
    x: a tensor or list of tensors
    x_shape: the dimensions of x as a tuple or an array of ints. If x is a
      list, then this is the list of shapes.
    y: a tensor
    y_shape: the dimensions of y as a tuple or an array of ints.
    x_init_value: (optional) a numpy array of the same shape as "x"
      representing the initial value of x. If x is a list, this should be a list
      of numpy arrays.  If this is none, the function will pick a random tensor
      as the initial value.
    delta: (optional) the amount of perturbation.
    init_targets: list of targets to run to initialize model params.
      TODO(mrry): remove this argument.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the Jacobian calculation.

  Returns:
    Two 2-d numpy arrays representing the theoretical and numerical
    Jacobian for dy/dx. Each has "x_size" rows and "y_size" columns
    where "x_size" is the number of elements in x and "y_size" is the
    number of elements in y. If x is a list, returns a list of such pairs,
    one for each tensor in x.
  """
  feeds = {} if extra_feed_dict is None else extra_feed_dict

  if isinstance(x, list):
    # The list helper builds the gradient ops and runs init_targets itself.
    return _compute_gradient_list(x, x_shape, y, y_shape, x_init_value, delta,
                                  init_targets, extra_feed_dict=feeds)

  if init_targets is not None:
    assert isinstance(init_targets, (list, tuple))
    for target in init_targets:
      target.run()
  dx, dy = _compute_dx_and_dy(x, y, y_shape)
  return _compute_gradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta,
                           extra_feed_dict=feeds)


@tf_export("test.compute_gradient_error")
def compute_gradient_error(x,
                           x_shape,
                           y,
                           y_shape,
                           x_init_value=None,
                           delta=1e-3,
                           init_targets=None,
                           extra_feed_dict=None):
  """Computes the gradient error.

  Computes the maximum error for dy/dx between the computed Jacobian and the
  numerically estimated Jacobian.

  This function will modify the tensors passed in as it adds more operations
  and hence changing the consumers of the operations of the input tensors.

  This function adds operations to the current session. To compute the error
  using a particular device, such as a GPU, use the standard methods for
  setting a device (e.g. using with sess.graph.device() or setting a device
  function in the session constructor).

  Args:
    x: a tensor or list of tensors
    x_shape: the dimensions of x as a tuple or an array of ints. If x is a
      list, then this is the list of shapes.
    y: a tensor
    y_shape: the dimensions of y as a tuple or an array of ints.
    x_init_value: (optional) a numpy array of the same shape as "x"
      representing the initial value of x. If x is a list, this should be a list
      of numpy arrays.  If this is none, the function will pick a random tensor
      as the initial value.
    delta: (optional) the amount of perturbation.
    init_targets: list of targets to run to initialize model params.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the Jacobian calculation.

  Returns:
    The maximum error in between the two Jacobians. When every Jacobian is
    empty the result is 0.
  """
  jacobians = compute_gradient(x, x_shape, y, y_shape, x_init_value, delta,
                               init_targets, extra_feed_dict=extra_feed_dict)
  # A single x yields one (theoretical, numeric) tuple; a list of x yields a
  # list of such tuples. Normalize to the list form.
  pairs = [jacobians] if isinstance(jacobians, tuple) else jacobians
  max_error = 0
  for theoretical, numeric in pairs:
    if not (theoretical.size or numeric.size):
      continue  # Zero size tensors contribute no error.
    max_error = np.maximum(max_error, np.fabs(theoretical - numeric).max())
  return max_error