aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac/python/ops/estimator.py
blob: 323234c40316757b8bc33564ba8a13b07c8858e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the high-level Fisher estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import numpy as np
import six

from tensorflow.contrib.kfac.python.ops import placement
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


# The linter is confused.
# pylint: disable=abstract-class-instantiated
def make_fisher_estimator(placement_strategy=None, **kwargs):
  """Creates Fisher estimator instances based on the placement strategy.

  For example if the `placement_strategy` is 'round_robin' then
  `FisherEstimatorRoundRobin` instance is returned.

  Args:
    placement_strategy: `string`, Strategy to be used for placing covariance
      variables, covariance ops and inverse ops. Check
      `placement.FisherEstimatorRoundRobin` for a concrete example.
   **kwargs: Arguments to be passed into `FisherEstimator` class initializer.

  Returns:
    An instance of class which inherits from `FisherEstimator` and the mixin
    which implements specific placement strategy. See,
    `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and
    `RoundRobinPlacementMixin`.

  Raises:
    ValueError: If the `placement_strategy` is not equal to 'round_robin'.
  """
  if placement_strategy in [None, "round_robin"]:
    return FisherEstimatorRoundRobin(**kwargs)
  else:
    raise ValueError("Unimplemented vars and ops "
                     "placement strategy : {}".format(placement_strategy))
# pylint: enable=abstract-class-instantiated


@six.add_metaclass(abc.ABCMeta)
class FisherEstimator(object):
  """Fisher estimator class supporting various approximations of the Fisher.

  This is an abstract base class which does not implement a strategy for
  placing covariance variables, covariance update ops and inverse update ops.
  The placement strategies are implemented in `placement.py`. See
  `FisherEstimatorRoundRobin` for example of a concrete subclass with
  a round-robin placement strategy.
  """

  def __init__(self,
               variables,
               cov_ema_decay,
               damping,
               layer_collection,
               exps=(-1,),
               estimation_mode="gradients",
               colocate_gradients_with_ops=True,
               name="FisherEstimator",
               compute_cholesky=False,
               compute_cholesky_inverse=False):
    """Create a FisherEstimator object.

    Args:
      variables: A `list` of variables or `callable` which returns the variables
          for which to estimate the Fisher. This must match the variables
          registered in layer_collection (if it is not None).
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: float. The damping factor used to stabilize training due to
          errors in the local approximation with the Fisher information matrix,
          and to regularize the update direction by making it closer to the
          gradient. (Higher damping means the update looks more like a standard
          gradient update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the Fisher
          blocks, Kronecker factors, and losses associated with the
          graph.
      exps: List of floats or ints. These represent the different matrix
          powers of the approximate Fisher that the FisherEstimator will be able
          to multiply vectors by. If the user asks for a matrix power other
          one of these (or 1, which is always supported), there will be a
          failure. (Default: (-1,))
      estimation_mode: The type of estimator to use for the Fishers.  Can be
          'gradients', 'empirical', 'curvature_prop', or 'exact'.
          (Default: 'gradients').  'gradients' is the basic estimation approach
          from the original K-FAC paper.  'empirical' computes the 'empirical'
          Fisher information matrix (which uses the data's distribution for the
          targets, as opposed to the true Fisher which uses the model's
          distribution) and requires that each registered loss have specified
          targets. 'curvature_propagation' is a method which estimates the
          Fisher using self-products of random 1/-1 vectors times "half-factors"
          of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
          Finally, 'exact' is the obvious generalization of Curvature
          Propagation to compute the exact Fisher (modulo any additional
          diagonal or Kronecker approximations) by looping over one-hot vectors
          for each coordinate of the output instead of using 1/-1 vectors.  It
          is more expensive to compute than the other three options by a factor
          equal to the output dimension, roughly speaking.
      colocate_gradients_with_ops: Whether we should request gradients be
          colocated with their respective ops. (Default: True)
      name: A string. A name given to this estimator, which is added to the
          variable scope when constructing variables and ops.
          (Default: "FisherEstimator")
      compute_cholesky: Bool. Whether or not the FisherEstimator will be
          able to multiply vectors by the Cholesky factor.
          (Default: False)
      compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
          will be able to multiply vectors by the Cholesky factor inverse.
          (Default: False)
    Raises:
      ValueError: If no losses have been registered with layer_collection.
    """
    self._variables = variables
    self._cov_ema_decay = cov_ema_decay
    self._damping = damping
    self._estimation_mode = estimation_mode
    self._layers = layer_collection
    self._gradient_fns = {
        "gradients": self._get_grads_lists_gradients,
        "empirical": self._get_grads_lists_empirical,
        "curvature_prop": self._get_grads_lists_curvature_prop,
        "exact": self._get_grads_lists_exact
    }
    self._colocate_gradients_with_ops = colocate_gradients_with_ops

    self._made_vars = False
    self._exps = exps
    self._compute_cholesky = compute_cholesky
    self._compute_cholesky_inverse = compute_cholesky_inverse

    self._name = name

  @property
  def variables(self):
    if callable(self._variables):
      return self._variables()
    else:
      return self._variables

  @property
  def damping(self):
    return self._damping

  @property
  def blocks(self):
    """All registered FisherBlocks."""
    return self._layers.get_blocks()

  @property
  def factors(self):
    """All registered FisherFactors."""
    return self._layers.get_factors()

  @property
  def name(self):
    return self._name

  @abc.abstractmethod
  def make_vars_and_create_op_thunks(self, scope=None):
    """Make vars and create op thunks with a specific placement strategy.

    For each factor, all of that factor's cov variables and their associated
    update ops will be placed on a particular device.  A new device is chosen
    for each factor by cycling through list of devices in the cov_devices
    argument. If cov_devices is None then no explicit device placement occurs.

    An analogous strategy is followed for inverse update ops, with the list of
    devices being given by the inv_devices argument.

    Inverse variables on the other hand are not placed on any specific device
    (they will just use the current the device placement context, whatever
    that happens to be).  The idea is that the inverse variable belong where
    they will be accessed most often, which is the device that actually applies
    the preconditioner to the gradient. The user will be responsible for setting
    the device context for this.

    Args:
      scope: A string or None.  If None it will be set to the name of this
        estimator (given by the name property). All variables will be created,
        and all thunks will execute, inside of a variable scope of the given
        name. (Default: None)

    Returns:
      cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
        the list of factors given by the "factors" property.
      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
        the list of factors given by the "factors" property.
    """
    pass

  def _apply_transformation(self, vecs_and_vars, transform):
    """Applies an block-wise transformation to the corresponding vectors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transform: A function of the form f(fb, vec), where vec is the vector
          to transform and fb is its corresponding block in the matrix, that
          returns the transformed vector.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)

    trans_vecs = utils.SequenceDict()

    for params, fb in self._layers.fisher_blocks.items():
      trans_vecs[params] = transform(fb, vecs[params])

    return [(trans_vecs[var], var) for _, var in vecs_and_vars]

  def multiply_inverse(self, vecs_and_vars):
    """Multiplies the vecs by the corresponding (damped) inverses of the blocks.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    return self.multiply_matpower(-1, vecs_and_vars)

  def multiply(self, vecs_and_vars):
    """Multiplies the vectors by the corresponding (damped) blocks.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    return self.multiply_matpower(1, vecs_and_vars)

  def multiply_matpower(self, exp, vecs_and_vars):
    """Multiplies the vecs by the corresponding matrix powers of the blocks.

    Args:
      exp: A float representing the power to raise the blocks by before
        multiplying it by the vector.
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    assert exp in self._exps

    fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
    return self._apply_transformation(vecs_and_vars, fcn)

  def multiply_cholesky(self, vecs_and_vars, transpose=False):
    """Multiplies the vecs by the corresponding Cholesky factors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transpose: Bool. If true the Cholesky factors are transposed before
        multiplying the vecs. (Default: False)

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    assert self._compute_cholesky

    fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
    return self._apply_transformation(vecs_and_vars, fcn)

  def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
    """Mults the vecs by the inverses of the corresponding Cholesky factors.

      Note: if you are using Cholesky inverse multiplication to sample from
      a matrix-variate Gaussian you will want to multiply by the transpose.
      Let L be the Cholesky factor of F and observe that

        L^-T * L^-1 = (L * L^T)^-1 = F^-1 .

      Thus we want to multiply by L^-T in order to sample from Gaussian with
      covariance F^-1.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transpose: Bool. If true the Cholesky factor inverses are transposed
        before multiplying the vecs. (Default: False)

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """
    assert self._compute_cholesky_inverse

    fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
    return self._apply_transformation(vecs_and_vars, fcn)

  def _instantiate_factors(self):
    """Instantiates FisherFactors' variables.

    Raises:
      ValueError: If estimation_mode was improperly specified at construction.
    """
    blocks = self.blocks
    tensors_to_compute_grads = [
        block.tensors_to_compute_grads() for block in blocks
    ]

    try:
      grads_lists = self._gradient_fns[self._estimation_mode](
          tensors_to_compute_grads)
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

    for grads_list, block in zip(grads_lists, blocks):
      block.instantiate_factors(grads_list, self.damping)

  def _check_vars_unmade_and_set_made_flag(self):
    if self._made_vars:
      raise Exception("Already made variables.")
    self._made_vars = True

  def made_vars(self):
    return self._made_vars

  def _register_matrix_functions(self):
    for block in self.blocks:
      for exp in self._exps:
        block.register_matpower(exp)
      if self._compute_cholesky:
        block.register_cholesky()
      if self._compute_cholesky_inverse:
        block.register_cholesky_inverse()

  def _finalize_layer_collection(self):
    self._layers.create_subgraph()
    self._layers.check_registration(self.variables)
    self._instantiate_factors()
    self._register_matrix_functions()

  def create_ops_and_vars_thunks(self, scope=None):
    """Create thunks that make the ops and vars on demand.

    This function returns 4 lists of thunks: cov_variable_thunks,
    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.

    The length of each list is the number of factors and the i-th element of
    each list corresponds to the i-th factor (given by the "factors" property).

    Note that the execution of these thunks must happen in a certain
    partial order.  The i-th element of cov_variable_thunks must execute
    before the i-th element of cov_update_thunks (and also the i-th element
    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks
    must execute before the i-th element of inv_update_thunks.

    TL;DR (oversimplified): Execute the thunks according to the order that
    they are returned.

    Args:
      scope: A string or None.  If None it will be set to the name of this
        estimator (given by the name property). All thunks will execute inside
        of a variable scope of the given name. (Default: None)
    Returns:
      cov_variable_thunks: A list of thunks that make the cov variables.
      cov_update_thunks: A list of thunks that make the cov update ops.
      inv_variable_thunks: A list of thunks that make the inv variables.
      inv_update_thunks: A list of thunks that make the inv update ops.
    """
    self._check_vars_unmade_and_set_made_flag()

    self._finalize_layer_collection()

    scope = self.name if scope is None else scope

    cov_variable_thunks = [
        self._create_cov_variable_thunk(factor, scope)
        for factor in self.factors
    ]
    cov_update_thunks = [
        self._create_cov_update_thunk(factor, scope) for factor in self.factors
    ]
    inv_variable_thunks = [
        self._create_inv_variable_thunk(factor, scope)
        for factor in self.factors
    ]
    inv_update_thunks = [
        self._create_inv_update_thunk(factor, scope) for factor in self.factors
    ]

    return (cov_variable_thunks, cov_update_thunks,
            inv_variable_thunks, inv_update_thunks)

  def _create_cov_variable_thunk(self, factor, scope):
    """Constructs a covariance variable thunk for a single FisherFactor."""

    def thunk():
      with variable_scope.variable_scope(scope):
        return factor.instantiate_cov_variables()

    return thunk

  def _create_cov_update_thunk(self, factor, scope):
    """Constructs a covariance update thunk for a single FisherFactor."""

    def thunk():
      with variable_scope.variable_scope(scope):
        return factor.make_covariance_update_op(self._cov_ema_decay)

    return thunk

  def _create_inv_variable_thunk(self, factor, scope):
    """Constructs a inverse variable thunk for a single FisherFactor."""

    def thunk():
      with variable_scope.variable_scope(scope):
        return factor.instantiate_inv_variables()

    return thunk

  def _create_inv_update_thunk(self, factor, scope):
    """Constructs an inverse update thunk for a single FisherFactor."""

    def thunk():
      with variable_scope.variable_scope(scope):
        return control_flow_ops.group(factor.make_inverse_update_ops())

    return thunk

  def _get_grads_lists_gradients(self, tensors):
    # Passing in a list of loss values is better than passing in the sum as
    # the latter creates unnessesary ops on the default device
    grads_flat = gradients_impl.gradients(
        self._layers.eval_losses_on_samples(),
        nest.flatten(tensors),
        colocate_gradients_with_ops=self._colocate_gradients_with_ops)
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_empirical(self, tensors):
    # Passing in a list of loss values is better than passing in the sum as
    # the latter creates unnecessary ops on the default device
    grads_flat = gradients_impl.gradients(
        self._layers.eval_losses(),
        nest.flatten(tensors),
        colocate_gradients_with_ops=self._colocate_gradients_with_ops)
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_transformed_random_signs(self):
    transformed_random_signs = []
    for loss in self._layers.losses:
      with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
        transformed_random_signs.append(
            loss.multiply_fisher_factor(
                utils.generate_random_signs(loss.fisher_factor_inner_shape)))
    return transformed_random_signs

  def _get_grads_lists_curvature_prop(self, tensors):
    loss_inputs = list(loss.inputs for loss in self._layers.losses)
    transformed_random_signs = self._get_transformed_random_signs()
    grads_flat = gradients_impl.gradients(
        nest.flatten(loss_inputs),
        nest.flatten(tensors),
        grad_ys=nest.flatten(transformed_random_signs),
        colocate_gradients_with_ops=self._colocate_gradients_with_ops)
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_exact(self, tensors):
    """No docstring required."""
    # Loop over all coordinates of all losses.
    grads_all = []
    for loss in self._layers.losses:
      with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
        for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
          transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
              index)
          grads_flat = gradients_impl.gradients(
              loss.inputs,
              nest.flatten(tensors),
              grad_ys=transformed_one_hot,
              colocate_gradients_with_ops=self._colocate_gradients_with_ops)
          grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
    return zip(*grads_all)


class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
                                FisherEstimator):
  """Fisher estimator which provides round robin device placement strategy."""
  pass