aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac/python/ops/optimizer.py
blob: b7f63d8d94a7a427eb57afefeda3939f0c530f8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The KFAC optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint disable=long-line
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
# pylint enable=long-line

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import gradient_descent


class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
  """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""

  def __init__(self,
               learning_rate,
               cov_ema_decay,
               damping,
               layer_collection,
               var_list=None,
               momentum=0.9,
               momentum_type="regular",
               norm_constraint=None,
               name="KFAC",
               estimation_mode="gradients",
               colocate_gradients_with_ops=True,
               batch_size=None,
               placement_strategy=None,
               **kwargs):
    """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should probably
          be set to 1.0 when using momentum_type = 'qmodel', but can still be
          set lowered if desired (effectively lowering the trust in the
          quadratic model.)
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
          the local approximation with the Fisher information matrix, and to
          regularize the update direction by making it closer to the gradient.
          If damping is adapted during training then this value is used for
          initializing damping variable.
          (Higher damping means the update looks more like a standard gradient
          update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
          blocks, kronecker factors, and losses associated with the
          graph.  The layer_collection cannot be modified after KfacOptimizer's
          initialization.
      var_list: Optional list or tuple of variables to train. Defaults to the
          list of variables collected in the graph under the key
          `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum decay constant to use. Only applies when
          momentum_type is 'regular' or 'adam'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
          'regular', 'adam', or 'qmodel'. (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled down
          so that its approximate squared Fisher norm v^T F v is at most the
          specified value. May only be used with momentum type 'regular'.
          (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers.  Can be
          'gradients', 'empirical', 'curvature_propagation', or 'exact'.
          (Default: 'gradients'). See the doc-string for FisherEstimator for
          more a more detailed description of these options.
      colocate_gradients_with_ops: Whether we should request gradients we
          compute in the estimator be colocated with their respective ops.
          (Default: True)
      batch_size: The size of the mini-batch. Only needed when momentum_type
          == 'qmodel' or when automatic adjustment is used.  (Default: None)
      placement_strategy: string, Device placement strategy used when creating
        covariance variables, covariance ops, and inverse ops.
        (Default: `None`)
      **kwargs: Arguments to be passesd to specific placement
        strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than 'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not 'regular'
          or 'adam'.
    """
    # Parameters to be passed to the Fisher estimator:
    self._variables = var_list or tf_variables.trainable_variables
    self._cov_ema_decay = cov_ema_decay
    self._layers = layer_collection
    self._estimation_mode = estimation_mode
    self._colocate_gradients_with_ops = colocate_gradients_with_ops

    # The below parameters are required only if damping needs to be adapated.
    # These parameters can be set by calling
    # set_damping_adaptation_params() explicitly.
    self._damping_adaptation_decay = 0.95
    self._damping_adaptation_interval = 5
    # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval)
    self._omega = (
        self._damping_adaptation_decay**self._damping_adaptation_interval)
    self._adapt_damping = False
    self._min_damping = 1e-5
    self._prev_train_batch = None
    self._is_chief = False
    self._loss_fn = None
    self._damping_constant = damping
    self._damping = None
    self._rho = None
    self._prev_loss = None
    self._q_model_change = None
    self._update_damping_op = None

    momentum_type = momentum_type.lower()
    legal_momentum_types = ["regular", "adam", "qmodel"]

    if momentum_type not in legal_momentum_types:
      raise ValueError("Unsupported momentum type {}. Must be one of {}."
                       .format(momentum_type, legal_momentum_types))
    if momentum_type != "regular" and norm_constraint is not None:
      raise ValueError("Update clipping is only supported with momentum "
                       "type 'regular'.")
    if momentum_type not in ["regular", "adam"] and momentum != 0:
      raise ValueError("Momentum must be unspecified if using a momentum_type "
                       "other than 'regular' or 'adam'.")

    # Extra parameters of the optimizer
    self._momentum = momentum
    self._momentum_type = momentum_type
    self._norm_constraint = norm_constraint
    self._batch_size = batch_size
    self._placement_strategy = placement_strategy

    with variable_scope.variable_scope(name):
      self._fisher_est = est.make_fisher_estimator(
          placement_strategy=placement_strategy,
          variables=self._variables,
          cov_ema_decay=self._cov_ema_decay,
          damping=self.damping,
          layer_collection=self._layers,
          exps=(-1,),
          estimation_mode=self._estimation_mode,
          colocate_gradients_with_ops=self._colocate_gradients_with_ops,
          **kwargs)

    super(KfacOptimizer, self).__init__(learning_rate, name=name)

  def set_damping_adaptation_params(self,
                                    is_chief,
                                    prev_train_batch,
                                    loss_fn,
                                    min_damping=1e-5,
                                    damping_adaptation_decay=0.99,
                                    damping_adaptation_interval=5):
    """Sets parameters required to adapt damping during training.

    When called, enables damping adaptation according to the Levenberg-Marquardt
    style rule described in Section 6.5 of "Optimizing Neural Networks with
    Kronecker-factored Approximate Curvature".

    Note that this function creates Tensorflow variables which store a few
    scalars and are accessed by the ops which update the damping (as part
    of the training op returned by the minimize() method).

    Args:
      is_chief: `Boolean`, `True` if the worker is chief.
      prev_train_batch: Training data used to minimize loss in the previous
        step. This will be used to evaluate loss by calling
        `loss_fn(prev_train_batch)`.
      loss_fn: `function` that takes as input training data tensor and returns
        a scalar loss.
      min_damping: `float`(Optional), Minimum value the damping parameter
        can take. Default value 1e-5.
      damping_adaptation_decay: `float`(Optional), The `damping` parameter is
        multiplied by the `damping_adaptation_decay` every
        `damping_adaptation_interval` number of iterations. Default value 0.99.
      damping_adaptation_interval: `int`(Optional), Number of steps in between
        updating the `damping` parameter. Default value 5.

    Raises:
      ValueError: If `set_damping_adaptation_params` is already called and the
        the `adapt_damping` is `True`.
    """
    if self._adapt_damping:
      raise ValueError("Damping adaptation parameters already set.")

    with variable_scope.variable_scope(self.get_name()):
      self._adapt_damping = True
      self._is_chief = is_chief
      self._prev_train_batch = prev_train_batch
      self._loss_fn = loss_fn
      self._damping_adaptation_decay = damping_adaptation_decay
      self._damping_adaptation_interval = damping_adaptation_interval
      self._omega = (
          self._damping_adaptation_decay**self._damping_adaptation_interval)
      self._min_damping = min_damping

      self._rho = variable_scope.get_variable(
          "rho", shape=(), dtype=dtypes.float32, trainable=False)  # LM ratio.
      self._prev_loss = variable_scope.get_variable(
          "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
      self._q_model_change = variable_scope.get_variable(
          "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
      self._damping = variable_scope.get_variable(
          "damping", initializer=self._damping_constant, trainable=False)

  @property
  def variables(self):
    return self._fisher_est.variables

  @property
  def damping(self):
    if self._damping:
      return self._damping
    else:
      return self._damping_constant

  @property
  def damping_adaptation_interval(self):
    return self._damping_adaptation_interval

  def make_vars_and_create_op_thunks(self):
    """Make vars and create op thunks.

    Returns:
      cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
        the list of factors given by the "factors" property.
      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
        the list of factors given by the "factors" property.
    """
    scope = self.get_name() + "/" + self._fisher_est.name
    return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)

  def create_ops_and_vars_thunks(self):
    """Create thunks that make the ops and vars on demand.

    This function returns 4 lists of thunks: cov_variable_thunks,
    cov_update_thunks, inv_variable_thunks, and inv_update_thunks.

    The length of each list is the number of factors and the i-th element of
    each list corresponds to the i-th factor (given by the "factors" property).

    Note that the execution of these thunks must happen in a certain
    partial order.  The i-th element of cov_variable_thunks must execute
    before the i-th element of cov_update_thunks (and also the i-th element
    of inv_update_thunks).  Similarly, the i-th element of inv_variable_thunks
    must execute before the i-th element of inv_update_thunks.

    TL;DR (oversimplified): Execute the thunks according to the order that
    they are returned.

    Returns:
      cov_variable_thunks: A list of thunks that make the cov variables.
      cov_update_thunks: A list of thunks that make the cov update ops.
      inv_variable_thunks: A list of thunks that make the inv variables.
      inv_update_thunks: A list of thunks that make the inv update ops.
    """
    scope = self.get_name() + "/" + self._fisher_est.name
    return self._fisher_est.create_ops_and_vars_thunks(scope=scope)

  def minimize(self, *args, **kwargs):
    # Should this variable scope encompass everything below?  Or will the super-
    # class make another copy of the same name scope?
    with variable_scope.variable_scope(self.get_name()):
      kwargs["var_list"] = kwargs.get("var_list") or self.variables
      if set(kwargs["var_list"]) != set(self.variables):
        raise ValueError("var_list doesn't match with set of Fisher-estimating "
                         "variables.")
      if self._adapt_damping and self._is_chief:
        global_step = kwargs.get("global_step", None)
        if not global_step:
          raise KeyError("global_step needs to be passed to optimizer.minimize "
                         "if damping parameter is adapted.")
        update_damping_op = self._update_damping(self._prev_train_batch,
                                                 global_step)
        with ops.control_dependencies([update_damping_op]):
          loss = args[0]
          loss_assign_op = state_ops.assign(self._prev_loss, loss)
          train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
          return control_flow_ops.group(loss_assign_op, train_op)
      else:
        return super(KfacOptimizer, self).minimize(*args, **kwargs)

  def compute_gradients(self, *args, **kwargs):
    # args[1] could be our var_list
    if len(args) > 1:
      var_list = args[1]
    else:
      kwargs["var_list"] = kwargs.get("var_list") or self.variables
      var_list = kwargs["var_list"]

    if set(var_list) != set(self.variables):
      raise ValueError("var_list doesn't match with set of Fisher-estimating "
                       "variables.")
    return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, *args, **kwargs):
    """Applies gradients to variables.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      *args: Additional arguments for super.apply_gradients.
      **kwargs: Additional keyword arguments for super.apply_gradients.

    Returns:
      An `Operation` that applies the specified gradients.
    """
    # In Python 3, grads_and_vars can be a zip() object which can only be
    # iterated over once. By converting it to a list, we ensure that it can be
    # iterated over more than once.
    grads_and_vars = list(grads_and_vars)

    # Compute step.
    steps_and_vars = self._compute_update_steps(grads_and_vars)

    # Update trainable variables with this step.
    return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
                                                      **kwargs)

  def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
    """Computes the squared (approximate) Fisher norm of the updates.

    This is defined as v^T F v, where F is the approximate Fisher matrix
    as computed by the estimator, and v = F^{-1} g, where g is the gradient.
    This is computed efficiently as v^T g.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
        Must be the result of calling `self._fisher_est.multiply_inverse`
        on `grads_and_vars`.

    Returns:
      Scalar representing the squared norm.

    Raises:
      ValueError: if the two list arguments do not contain the same variables,
        in the same order.
    """
    for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
      if gvar is not pgvar:
        raise ValueError("The variables referenced by the two arguments "
                         "must match.")
    terms = [
        math_ops.reduce_sum(grad * pgrad)
        for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
    ]
    return math_ops.reduce_sum(terms)

  def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
    """Computes the scale factor for the update to satisfy the norm constraint.

    Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
    F is the approximate Fisher matrix, and r is the update vector, i.e.
    -alpha * v, where alpha is the learning rate, and v is the preconditioned
    gradient.

    This is based on Section 5 of Ba et al., Distributed Second-Order
    Optimization using Kronecker-Factored Approximations. Note that they
    absorb the learning rate alpha (which they denote eta_max) into the formula
    for the coefficient, while in our implementation, the rescaling is done
    before multiplying by alpha. Hence, our formula differs from theirs by a
    factor of alpha.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
        Must be the result of calling `self._fisher_est.multiply_inverse`
        on `grads_and_vars`.

    Returns:
      Scalar representing the coefficient which should be applied to the
      preconditioned gradients to satisfy the norm constraint.
    """
    sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
                                             precon_grads_and_vars)
    sq_norm_up = sq_norm_grad * self._learning_rate**2
    return math_ops.minimum(1.,
                            math_ops.sqrt(self._norm_constraint / sq_norm_up))

  def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
    """Rescales the preconditioned gradients to satisfy the norm constraint.

    Rescales the preconditioned gradients such that the resulting update r
    (after multiplying by the learning rate) will satisfy the norm constraint.
    This constraint is that r^T F r <= C, where F is the approximate Fisher
    matrix, and C is the norm_constraint attribute. See Section 5 of
    Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
    Approximations.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
        Must be the result of calling `self._fisher_est.multiply_inverse`
        on `grads_and_vars`.

    Returns:
      List of (rescaled preconditioned gradient, variable) pairs.
    """
    coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
    return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]

  def _compute_prev_updates(self, variables):
    """Computes previous updates as negative velocities scaled by learning rate.

    Args:
      variables: List of variables in the graph that the update will be
          applied to.

    Returns:
      List of previous updates applied to the `variables`.
    """
    return list(
        -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
        for var in variables)

  def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
                                  variables):
    """Compute optimal update hyperparameters from the quadratic model.

    More specifically, if L is the loss we minimize a quadratic approximation
    of L(theta + d) which we denote by qmodel(d) with
    d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where

      qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) .

    Unlike in the KL clipping approach we use the non-approximated quadratic
    model where the curvature matrix C is the true Fisher on the current
    mini-batch (computed without any approximations beyond mini-batch sampling),
    with the usual Tikhonov damping/regularization applied,

      C = F + damping * I

    See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
    the formula.  See Appendix C for a discussion of the trick of using
    a factorized Fisher matrix to more efficiently compute the required
    vector-matrix-vector products.

    Note that the elements of all 4 lists passed to this function must
    be in correspondence with each other.

    Args:
      precon_grads: List of preconditioned gradients.
      prev_updates: List of updates computed at the previous iteration.
      grads: List of gradients.
      variables: List of variables in the graph that the update will be
          applied to. (Note that this function doesn't actually apply the
          update.)

    Returns:
      (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
      quadratic model, and
      qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
                    = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
    """

    cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
                                                      variables)

    # compute the matrix-vector products with the transposed Fisher factor
    fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
    fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
    batch_size = math_ops.cast(
        self._batch_size, dtype=fft_precon_grads[0].dtype)

    # compute the entries of the 2x2 matrix
    m_11 = (
        _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
        self.damping * _inner_product_list(precon_grads, precon_grads))

    m_21 = (
        _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
        self.damping * _inner_product_list(prev_updates, precon_grads))

    m_22 = (
        _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
        self.damping * _inner_product_list(prev_updates, prev_updates))

    def non_zero_prevupd_case():
      r"""Computes optimal (alpha, mu) given non-zero previous update.

      We solve the full 2x2 linear system. See Martens & Grosse (2015),
      Section 7, definition of $\alpha^*$ and $\mu^*$.

      Returns:
        (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
        the quadratic model, and
        qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
      """
      m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])

      c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
                                 [_inner_product_list(grads, prev_updates)]])

      sol = -1. * _two_by_two_solve(m, c)
      alpha = sol[0]
      mu = sol[1]
      qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)

      return alpha, mu, qmodel_change

    def zero_prevupd_case():
      r"""Computes optimal (alpha, mu) given all-zero previous update.

      The linear system reduces to 1x1. See Martens & Grosse (2015),
      Section 6.4, definition of $\alpha^*$.

      Returns:
        (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
        quadratic model, and
        qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
      """
      m = m_11
      c = _inner_product_list(grads, precon_grads)

      alpha = -c / m
      mu = 0.0
      qmodel_change = 0.5 * alpha * c

      return alpha, mu, qmodel_change

    return control_flow_ops.cond(
        math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)

  def _assign_q_model_change(self, q_model_change):
    """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.

    Note only the chief worker does the assignment.

    Args:
      q_model_change: Scalar tensor of type `float32`.

    Returns:
      If `adapt_damping` is `True` then returns an assign op, Otherwise returns
      a no_op().
    """
    if self._adapt_damping and self._is_chief:
      q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
    else:
      q_model_assign_op = control_flow_ops.no_op()
    return q_model_assign_op

  def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
                                          precon_grads_and_vars):
    """Wrapper function for `self._compute_qmodel_hyperparams`.

    Constructs a list of preconditioned gradients and variables. Also creates a
    op to asssign the computed q model change to `self._q_model_change`.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      precon_grads_and_vars: List of (preconditioned gradients, variable)
        pairs.

    Returns:
      (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
      the quadratic model, `q_model_assign_op` assigns the computed q model
      change to `self._q_model_change`.
    """
    precon_grads = list(
        precon_grad for (precon_grad, _) in precon_grads_and_vars)
    grads = list(grad for (grad, _) in grads_and_vars)
    variables = list(var for (_, var) in grads_and_vars)
    prev_updates = self._compute_prev_updates(variables)
    # Compute optimal velocity update parameters according to quadratic model
    alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
        precon_grads, prev_updates, grads, variables)

    return alpha, mu, self._assign_q_model_change(q_model_change)

  def _compute_update_steps(self, grads_and_vars):
    """Computes the update steps for the variables given the gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.

    Returns:
      A list of tuple (assign_op ,var) where `assign_op` assigns the update
      steps to `var`.
    """

    if self._momentum_type == "regular":
      # Compute "preconditioned" gradient.
      precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)

      # Apply "KL clipping" if asked for.
      if self._norm_constraint is not None:
        precon_grads_and_vars = self._clip_updates(grads_and_vars,
                                                   precon_grads_and_vars)

      # Update the velocity with this and return it as the step.
      if self._adapt_damping and self._is_chief:
        _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
            grads_and_vars, precon_grads_and_vars)
        with ops.control_dependencies([q_model_assign_op]):
          return self._update_velocities(precon_grads_and_vars, self._momentum)
      else:
        return self._update_velocities(precon_grads_and_vars, self._momentum)
    elif self._momentum_type == "adam":
      # Update velocity.
      velocities_and_vars = self._update_velocities(grads_and_vars,
                                                    self._momentum)
      # Return "preconditioned" velocity vector as the step.
      return self._fisher_est.multiply_inverse(velocities_and_vars)

    elif self._momentum_type == "qmodel":
      # Compute "preconditioned" gradient.
      precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)

      # Compute optimal velocity update parameters according to quadratic model
      alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
          grads_and_vars, precon_grads_and_vars)

      with ops.control_dependencies([q_model_assign_op]):
        return self._update_velocities(
            precon_grads_and_vars, mu, vec_coeff=-alpha)

  def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
    """Updates the velocities of the variables with the given vectors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      decay: How much to decay the old velocity by.  This is often referred to
        as the 'momentum constant'.
      vec_coeff: Coefficient to apply to the vectors before adding them to the
        velocity.

    Returns:
      A list of (velocity, var) indicating the new velocity for each var.
    """

    def _update_velocity(vec, var):
      velocity = self._zeros_slot(var, "velocity", self._name)
      with ops.colocate_with(velocity):
        # NOTE(mattjj): read/modify/write race condition not suitable for async.

        # Compute the new velocity for this variable.
        new_velocity = decay * velocity + vec_coeff * vec

        # Save the updated velocity.
        return (array_ops.identity(velocity.assign(new_velocity)), var)

    # Go through variable and update its associated part of the velocity vector.
    return [_update_velocity(vec, var) for vec, var in vecs_and_vars]

  def _update_damping(self, prev_batch, global_step):
    """Adapts damping parameter. Check KFAC (Section 6.5) for the details.

    The damping parameter is updated according to the Levenberg-Marquardt rule
    every `self._damping_adaptation_interval` iterations.

    Args:
      prev_batch: Tensor or tuple of tensors which can be passed to
        `self._loss_fn` to evaluate loss.
      global_step: `Variable` which keeps track of number of times the training
        variables have been updated.
    Returns:
      A `tf.cond` op which updates the damping parameter.
    """
    def compute_damping():
      """"Adapts damping parameter based on "reduction ratio".

      Reduction ratio captures how closely the quadratic approximation to the
      loss function approximates the actual loss within a trust region. The
      damping update tries to make the damping as small as possible while
      maintaining the property that the quadratic model remains a good local
      approximation to the loss function.

      Returns:
        An Op to assign newly computed damping value to `self._damping`.
      """
      prev_batch_loss = self._loss_fn(prev_batch)
      with ops.control_dependencies([prev_batch_loss]):
        rho_assign = self._rho.assign(
            (prev_batch_loss - self._prev_loss) / self._q_model_change)
        with ops.control_dependencies([rho_assign]):
          new_damping = control_flow_ops.case(
              [(self._rho < 0.25, lambda: self.damping / self._omega),
               (self._rho > 0.75, lambda: self.damping * self._omega)],
              lambda: self.damping)
          with ops.control_dependencies([new_damping]):
            new_damping_min = math_ops.maximum(new_damping, self._min_damping)
            return control_flow_ops.group(self._damping.assign(new_damping_min))

    return control_flow_ops.cond(
        math_ops.equal(
            math_ops.mod(global_step + 1, self._damping_adaptation_interval),
            0), compute_damping, control_flow_ops.no_op)


def _inner_product_list(list1, list2):
  return math_ops.add_n(
      [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])


def _two_by_two_solve(m, c):
  # it might be better just to crank out the exact formula for 2x2 inverses
  return math_ops.matmul(linalg_ops.matrix_inverse(m), c)