aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
blob: 9fa6eb7dcd12d7c6474d176198c1e47f1ec6fd4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FisherBlock definitions.

This library contains classes for estimating blocks in a model's Fisher
Information matrix. Suppose one has a model that parameterizes a posterior
distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
Fisher Information matrix is given by,

  $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$

where,

  $$v(x, y, params) = (d / d params) log p(y | x, params)$$

and the expectation is taken with respect to the data's distribution for 'x' and
the model's posterior distribution for 'y',

  x ~ p(x)
  y ~ p(y | x, params)

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import enum  # pylint: disable=g-bad-import-order

import numpy as np
import six

from tensorflow.contrib.kfac.python.ops import fisher_factors
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

# For blocks corresponding to convolutional layers, or any type of block where
# the parameters can be thought of as being replicated in time or space,
# we want to adjust the scale of the damping by
#   damping /= num_replications ** NORMALIZE_DAMPING_POWER
# A falsy value (e.g. 0) disables this normalization; see normalize_damping().
# Can be overridden at runtime via set_global_constants().
NORMALIZE_DAMPING_POWER = 1.0

# Methods for adjusting damping for FisherBlocks. See
# compute_pi_adjusted_damping() for details.
#   PI_OFF_NAME:       use the same damping value for both Kronecker factors.
#   PI_TRACENORM_NAME: rescale the per-factor damping by the trace-norm ratio
#                      pi computed in compute_pi_tracenorm().
PI_OFF_NAME = "off"
PI_TRACENORM_NAME = "tracenorm"
# The method currently in effect; can be overridden via set_global_constants().
PI_TYPE = PI_TRACENORM_NAME


def set_global_constants(normalize_damping_power=None, pi_type=None):
  """Overrides module-level configuration used by the FisherBlock classes.

  Any argument left as None keeps the current value of its setting.

  Args:
    normalize_damping_power: If not None, the new value of
      NORMALIZE_DAMPING_POWER.
    pi_type: If not None, the new value of PI_TYPE (e.g. PI_OFF_NAME or
      PI_TRACENORM_NAME).
  """
  global NORMALIZE_DAMPING_POWER, PI_TYPE

  if normalize_damping_power is not None:
    NORMALIZE_DAMPING_POWER = normalize_damping_power
  if pi_type is not None:
    PI_TYPE = pi_type


def normalize_damping(damping, num_replications):
  """Divides damping by num_replications ** NORMALIZE_DAMPING_POWER.

  Args:
    damping: The damping value (float or Tensor).
    num_replications: How many times the parameters are replicated (e.g. the
      number of locations a convolution filter is applied at).

  Returns:
    The rescaled damping, or `damping` itself when NORMALIZE_DAMPING_POWER is
    falsy.
  """
  power = NORMALIZE_DAMPING_POWER
  if not power:
    return damping
  return damping / (num_replications ** power)


def compute_pi_tracenorm(left_cov, right_cov):
  r"""Computes the scalar constant pi for Tikhonov regularization/damping.

  $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
  See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
    right_cov: A LinearOperator object. The right Kronecker factor "covariance".

  Returns:
    The computed scalar constant pi for these Kronecker Factors (as a Tensor).
  """
  left_dim = int(left_cov.domain_dimension)
  right_dim = int(right_cov.domain_dimension)
  # (tr(A) / dim(A)) / (tr(B) / dim(B)) == (tr(A) * dim(B)) / (tr(B) * dim(A)),
  # so cross-multiply by the other factor's dimension instead of dividing.
  numerator = left_cov.trace() * right_dim
  denominator = right_cov.trace() * left_dim
  return math_ops.sqrt(numerator / denominator)


def compute_pi_adjusted_damping(left_cov, right_cov, damping):
  """Splits `damping` into per-factor damping values for a Kronecker product.

  The splitting rule is selected by the module-level PI_TYPE:
    - PI_TRACENORM_NAME: returns (damping * pi, damping / pi), where pi is
      computed by compute_pi_tracenorm(left_cov, right_cov).
    - PI_OFF_NAME: returns (damping, damping).

  Args:
    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
    right_cov: A LinearOperator object. The right Kronecker factor "covariance".
    damping: The damping value (float or Tensor) to split between the factors.

  Returns:
    A pair (left_damping, right_damping).

  Raises:
    ValueError: If PI_TYPE is not one of PI_TRACENORM_NAME or PI_OFF_NAME.
  """
  if PI_TYPE == PI_TRACENORM_NAME:
    pi = compute_pi_tracenorm(left_cov, right_cov)
    return (damping * pi, damping / pi)

  if PI_TYPE == PI_OFF_NAME:
    return (damping, damping)

  # Previously this fell through and returned None, which callers then
  # indexed, producing an opaque TypeError far from the misconfiguration.
  raise ValueError("Unrecognized PI_TYPE {!r}; expected one of {!r} or {!r}."
                   .format(PI_TYPE, PI_TRACENORM_NAME, PI_OFF_NAME))


class PackagedFunc(object):
  """Pairs a zero-argument Python callable with a stable identifier.

  Enables stable names for lambdas, which otherwise have no meaningful
  identity of their own.
  """

  def __init__(self, func, func_id):
    """Initializes PackagedFunc.

    Args:
      func: a zero-arg Python function.
      func_id: a hashable, function that produces a hashable, or a list/tuple
        thereof.
    """
    self._func = func
    if not isinstance(func_id, (tuple, list)):
      # Normalize a lone ID element into a 1-tuple.
      func_id = (func_id,)
    self._func_id = func_id

  def __call__(self):
    """Invokes the wrapped function and returns its result."""
    return self._func()

  @property
  def func_id(self):
    """A hashable identifier for this function.

    Callable elements of the stored ID are called (with no arguments) every
    time this property is read; all other elements are used as-is.
    """
    def _resolve(part):
      return part() if callable(part) else part
    return tuple(map(_resolve, self._func_id))


def _package_func(func, func_id):
  """Returns `func` wrapped in a PackagedFunc identified by `func_id`."""
  packaged = PackagedFunc(func, func_id)
  return packaged


@six.add_metaclass(abc.ABCMeta)
class FisherBlock(object):
  """Abstract base class for objects modeling approximate Fisher matrix blocks.

  Subclasses must implement instantiate_factors, register_matpower,
  register_cholesky, register_cholesky_inverse, multiply_matpower,
  multiply_cholesky, multiply_cholesky_inverse, tensors_to_compute_grads, and
  the num_registered_towers property. register_inverse, multiply_inverse and
  multiply are provided here in terms of register_matpower and
  multiply_matpower.
  """

  def __init__(self, layer_collection):
    """Initializes the block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
    """
    self._layer_collection = layer_collection

  @abc.abstractmethod
  def instantiate_factors(self, grads_list, damping):
    """Creates and registers the component factors of this Fisher block.

    Args:
      grads_list: A list of gradients (each a Tensor or tuple of Tensors) with
          respect to the tensors returned by tensors_to_compute_grads() that
          are to be used to estimate the block.
      damping: The damping factor (float or Tensor).
    """
    pass

  @abc.abstractmethod
  def register_matpower(self, exp):
    """Registers a matrix power to be computed by the block.

    Args:
      exp: A float representing the power to raise the block by.
    """
    pass

  @abc.abstractmethod
  def register_cholesky(self):
    """Registers a Cholesky factor to be computed by the block."""
    pass

  @abc.abstractmethod
  def register_cholesky_inverse(self):
    """Registers an inverse Cholesky factor to be computed by the block."""
    pass

  def register_inverse(self):
    """Registers a matrix inverse to be computed by the block."""
    self.register_matpower(-1)

  @abc.abstractmethod
  def multiply_matpower(self, vector, exp):
    """Multiplies the vector by the (damped) matrix-power of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      exp: A float representing the power to raise the block by before
        multiplying it by the vector.

    Returns:
      The vector left-multiplied by the (damped) matrix-power of the block.
    """
    pass

  def multiply_inverse(self, vector):
    """Multiplies the vector by the (damped) inverse of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

    Returns:
      The vector left-multiplied by the (damped) inverse of the block.
    """
    return self.multiply_matpower(vector, -1)

  def multiply(self, vector):
    """Multiplies the vector by the (damped) block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

    Returns:
      The vector left-multiplied by the (damped) block.
    """
    return self.multiply_matpower(vector, 1)

  @abc.abstractmethod
  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies the vector by the (damped) Cholesky-factor of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      transpose: Bool. If true the Cholesky factor is transposed before
        multiplying the vector. (Default: False)

    Returns:
      The vector left-multiplied by the (damped) Cholesky-factor of the block.
    """
    pass

  @abc.abstractmethod
  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies vector by the (damped) inverse Cholesky-factor of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      transpose: Bool. If true the Cholesky factor inverse is transposed
        before multiplying the vector. (Default: False)

    Returns:
      Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
    """
    pass

  @abc.abstractmethod
  def tensors_to_compute_grads(self):
    """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
    """
    pass

  @abc.abstractproperty
  def num_registered_towers(self):
    """Number of towers registered for this FisherBlock.

    Typically equal to the number of towers in a multi-tower setup.
    """
    pass


class FullFB(FisherBlock):
  """FisherBlock backed by a single full matrix estimate (no approximations).

  Because the whole matrix is materialized, FullFB should only ever be used
  for very low dimensional parameters.

  Note that this uses the naive "square the sum estimator", and so is applicable
  to any type of parameter in principle, but has very high variance.
  """

  def __init__(self, layer_collection, params):
    """Creates a FullFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: The parameters of this layer (Tensor or tuple of Tensors).
    """
    self._batch_sizes = []  # One entry per registered tower.
    self._params = params
    super(FullFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    # Damping is constant here, so its own value serves as its ID.
    self._damping_func = _package_func(lambda: damping, (damping,))
    factor_args = (grads_list, self._batch_size)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullFactor, factor_args)

  def register_matpower(self, exp):
    self._factor.register_matpower(exp, self._damping_func)

  def register_cholesky(self):
    self._factor.register_cholesky(self._damping_func)

  def register_cholesky_inverse(self):
    self._factor.register_cholesky_inverse(self._damping_func)

  def _multiply_matrix(self, matrix, vector, transpose=False):
    """Multiplies `vector` (Tensor or tuple of Tensors) by `matrix`.

    The vector is flattened with utils.tensors_to_column, multiplied by
    `matrix` (its adjoint when `transpose` is True), and converted back with
    utils.column_to_tensors using `vector` as the template.
    """
    column = utils.tensors_to_column(vector)
    product = matrix.matmul(column, adjoint=transpose)
    return utils.column_to_tensors(vector, product)

  def multiply_matpower(self, vector, exp):
    power = self._factor.get_matpower(exp, self._damping_func)
    return self._multiply_matrix(power, vector)

  def multiply_cholesky(self, vector, transpose=False):
    chol = self._factor.get_cholesky(self._damping_func)
    return self._multiply_matrix(chol, vector, transpose=transpose)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    chol_inv = self._factor.get_cholesky_inverse(self._damping_func)
    return self._multiply_matrix(chol_inv, vector, transpose=transpose)

  def full_fisher_block(self):
    """Explicitly constructs the full Fisher block."""
    cov_operator = self._factor.get_cov_as_linear_operator()
    return cov_operator.to_dense()

  def tensors_to_compute_grads(self):
    return self._params

  def register_additional_tower(self, batch_size):
    """Register an additional tower.

    Args:
      batch_size: The batch size, used in the covariance estimator.
    """
    self._batch_sizes.append(batch_size)

  @property
  def num_registered_towers(self):
    return len(self._batch_sizes)

  @property
  def _batch_size(self):
    """Total batch size, summed over all registered towers."""
    return math_ops.reduce_sum(self._batch_sizes)


@six.add_metaclass(abc.ABCMeta)
class DiagonalFB(FisherBlock):
  """A base class for FisherBlocks that use diagonal approximations.

  Subclasses set self._factor and self._damping_func in instantiate_factors.
  Matrix powers, Cholesky factors and their inverses are requested from the
  factor on demand, so the register_* methods have nothing to do.
  """

  def register_matpower(self, exp):
    """No-op: in the diagonal case matrix powers are computed on demand."""

  def register_cholesky(self):
    """No-op: in the diagonal case Cholesky factors are computed on demand."""

  def register_cholesky_inverse(self):
    """No-op: in the diagonal case Cholesky inverses are computed on demand."""

  def _multiply_matrix(self, matrix, vector):
    """Multiplies `vector` (Tensor or tuple of Tensors) by `matrix`."""
    result_column = matrix.matmul(utils.tensors_to_column(vector))
    return utils.column_to_tensors(vector, result_column)

  def multiply_matpower(self, vector, exp):
    power = self._factor.get_matpower(exp, self._damping_func)
    return self._multiply_matrix(power, vector)

  def multiply_cholesky(self, vector, transpose=False):
    # `transpose` is accepted for interface compatibility but ignored; a
    # diagonal matrix is its own transpose (assumes the factor's Cholesky
    # operator is diagonal, per this class's diagonal approximation).
    chol = self._factor.get_cholesky(self._damping_func)
    return self._multiply_matrix(chol, vector)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    # `transpose` is ignored for the same reason as in multiply_cholesky.
    chol_inv = self._factor.get_cholesky_inverse(self._damping_func)
    return self._multiply_matrix(chol_inv, vector)

  def full_fisher_block(self):
    """Explicitly constructs the (diagonal) Fisher block as a dense matrix."""
    cov_operator = self._factor.get_cov_as_linear_operator()
    return cov_operator.to_dense()


class NaiveDiagonalFB(DiagonalFB):
  """FisherBlock using a diagonal matrix approximation.

  This kind of approximation is generically applicable, but quite primitive.

  Note that this uses the naive "square the sum estimator", and so is applicable
  to any type of parameter in principle, but has very high variance.
  """

  def __init__(self, layer_collection, params):
    """Creates a NaiveDiagonalFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: The parameters of this layer (Tensor or tuple of Tensors).
    """
    self._params = params
    self._batch_sizes = []  # One entry per registered tower.
    super(NaiveDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    # Damping is constant here, so its own value serves as its ID.
    self._damping_func = _package_func(lambda: damping, (damping,))
    factor_args = (grads_list, self._batch_size)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.NaiveDiagonalFactor, factor_args)

  def tensors_to_compute_grads(self):
    return self._params

  def register_additional_tower(self, batch_size):
    """Register an additional tower.

    Args:
      batch_size: The batch size, used in the covariance estimator.
    """
    self._batch_sizes.append(batch_size)

  @property
  def num_registered_towers(self):
    return len(self._batch_sizes)

  @property
  def _batch_size(self):
    """Total batch size, summed over all registered towers."""
    return math_ops.reduce_sum(self._batch_sizes)


class InputOutputMultiTower(object):
  """Mix-in class for blocks with inputs & outputs and multiple mini-batches.

  Keeps parallel per-tower lists of layer inputs and outputs, appended to by
  register_additional_tower.
  """

  def __init__(self, *args, **kwargs):
    self.__inputs = []
    self.__outputs = []
    super(InputOutputMultiTower, self).__init__(*args, **kwargs)

  def _process_data(self, grads_list):
    """Formats the registered inputs and `grads_list` for the factor classes.

    The expected format depends on the global configuration variable
    fisher_factors.TOWER_STRATEGY.

    On entry, self._inputs is a list (over towers) of Tensors, and
    `grads_list` is a list (over sources) of such per-tower lists.

    With TOWER_STRATEGY == "concat", the returned inputs are a 1-tuple holding
    a single PartitionedTensor that concatenates all towers of self._inputs.
    The returned grads_list is a tuple (over sources) of such 1-tuples.

    With TOWER_STRATEGY == "separate", both keep their per-tower layout, with
    the outer sequence converted to a tuple.

    Args:
      grads_list: grads_list in its initial format (see above).

    Returns:
      inputs: self._inputs transformed into the appropriate format (see
        above).
      grads_list: grads_list transformed into the appropriate format (see
        above).

    Raises:
      ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
    """
    strategy = fisher_factors.TOWER_STRATEGY
    if strategy == "concat":
      # Merge towers into one PartitionedTensor, wrapped in a 1-tuple because
      # the factors expect a sequence over towers. The leading sources
      # dimension of grads_list is preserved.
      merged_inputs = (utils.PartitionedTensor(self._inputs),)
      merged_grads = tuple((utils.PartitionedTensor(per_source),)
                           for per_source in grads_list)
      return merged_inputs, merged_grads
    if strategy == "separate":
      return tuple(self._inputs), tuple(grads_list)
    raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                     "'concat' or 'separate'.")

  def tensors_to_compute_grads(self):
    """Tensors to compute derivative of loss with respect to."""
    return tuple(self._outputs)

  def register_additional_tower(self, inputs, outputs):
    """Records the inputs and outputs of one more tower."""
    self._inputs.append(inputs)
    self._outputs.append(outputs)

  @property
  def num_registered_towers(self):
    count = len(self._inputs)
    # Inputs and outputs are always registered in pairs.
    assert count == len(self._outputs)
    return count

  @property
  def _inputs(self):
    return self.__inputs

  @property
  def _outputs(self):
    return self.__outputs


class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
  """FisherBlock for fully-connected (dense) layers using a diagonal approx.

  Estimates the diagonal entries of the Fisher Information matrix for a fully
  connected layer. Unlike NaiveDiagonalFB, this uses the low-variance "sum of
  squares" estimator.

  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. We are interested in Fisher(params)[i, i]. This is,

    $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]$$

  Consider a fully connected layer in this model with (unshared) weight matrix
  'w'. For an example 'x' that produces layer inputs 'a' and output
  preactivations 's',

    $$v(x, y, w) = vec( a (d loss / d s)^T )$$

  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
  to the layer's parameters 'w'.
  """

  def __init__(self, layer_collection, has_bias=False):
    """Creates a FullyConnectedDiagonalFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      has_bias: Whether the component Kronecker factors have an additive bias.
          (Default: False)
    """
    self._has_bias = has_bias
    super(FullyConnectedDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    inputs, processed_grads = self._process_data(grads_list)

    factor_args = (inputs, processed_grads, self._has_bias)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedDiagonalFactor, factor_args)

    # Damping is constant here, so its own value serves as its ID.
    self._damping_func = _package_func(lambda: damping, (damping,))


class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
  """FisherBlock for 2-D convolutional layers using a diagonal approx.

  Estimates the diagonal entries of the Fisher Information matrix for a
  convolutional layer. Unlike NaiveDiagonalFB, this uses the low-variance "sum
  of squares" estimator.

  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. We are interested in Fisher(params)[i, i]. This is,

    $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]$$

  Consider a convolutional layer in this model with (unshared) filter matrix
  'w'. For an example image 'x' that produces layer inputs 'a' and output
  preactivations 's',

    $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$

  where 'loc' is a single (x, y) location in an image.

  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
  to the layer's parameters 'w'.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               data_format=None,
               dilations=None):
    """Creates a ConvDiagonalFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [kernel_height, kernel_width,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels].
      strides: The stride size in this layer (1-D Tensor of length 4).
      padding: The padding in this layer (e.g. "SAME").
      data_format: str or None. Format of input data.
      dilations: List of 4 ints or None. Rate for dilation along all dimensions.
        None means no dilation ([1, 1, 1, 1]).

    Raises:
      ValueError: if strides is not length-4.
      ValueError: if dilations is not length-4.
      ValueError: if channel is not last dimension.
      ValueError: if the convolution filter is not rank 4.
    """
    if len(strides) != 4:
      raise ValueError("strides must contain 4 numbers.")

    dilations = [1, 1, 1, 1] if dilations is None else dilations
    if len(dilations) != 4:
      raise ValueError("dilations must contain 4 numbers.")

    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("data_format must be channels-last.")

    self._strides = maybe_tuple(strides)
    self._padding = padding
    self._data_format = data_format
    self._dilations = maybe_tuple(dilations)

    # A tuple/list of params means (kernel, bias); otherwise params is the
    # kernel alone.
    self._has_bias = isinstance(params, (tuple, list))
    kernel = params[0] if self._has_bias else params
    self._filter_shape = tuple(kernel.shape.as_list())
    if len(self._filter_shape) != 4:
      raise ValueError(
          "Convolution filter must be of shape"
          " [filter_height, filter_width, in_channels, out_channels].")

    super(ConvDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    inputs, grads_list = self._process_data(grads_list)

    # The number of locations the filter is applied at is inferred from the
    # static shape of the first inputs entry and the strides.
    self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
                                             self._strides)

    factor_args = (inputs, grads_list, self._filter_shape, self._strides,
                   self._padding, self._data_format, self._dilations,
                   self._has_bias)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvDiagonalFactor, factor_args)

    def damping_func():
      # Normalize damping by the number of locations, then multiply by it.
      locations = self._num_locations
      return locations * normalize_damping(damping, locations)

    damping_id = (self._num_locations, "mult", "normalize_damping", damping,
                  self._num_locations)
    self._damping_func = _package_func(damping_func, damping_id)


class KroneckerProductFB(FisherBlock):
  """A base class for blocks with separate input and output Kronecker factors.

  The Fisher block is approximated as a Kronecker product of the input and
  output factors, scaled by `_renorm_coeff`:

    FB = _renorm_coeff * kron(input_factor, output_factor)

  Subclasses set `self._input_factor` and `self._output_factor` and then call
  `_setup_damping()` (typically from their `instantiate_factors()`).
  """

  def _setup_damping(self, damping, normalization=None):
    """Makes functions that compute the damping values for both factors.

    The square root of the (optionally normalized) damping is passed to
    `compute_pi_adjusted_damping()` together with both factors' covariances.
    Element 0 of its result damps the input factor and element 1 damps the
    output factor.

    Args:
      damping: 0-D Tensor or float. Damping for the whole Fisher block.
      normalization: None, or a value passed as the second argument of
        `normalize_damping(damping, normalization)` before the damping is
        split between the two factors.
    """
    def compute_damping():
      if normalization is not None:
        maybe_normalized_damping = normalize_damping(damping, normalization)
      else:
        maybe_normalized_damping = damping

      return compute_pi_adjusted_damping(
          self._input_factor.get_cov_as_linear_operator(),
          self._output_factor.get_cov_as_linear_operator(),
          maybe_normalized_damping**0.5)

    # Hashable description of how the damping is computed. It is handed to
    # _package_func() alongside the function itself, with a ("ref", i) suffix
    # identifying which element of compute_damping()'s result is used.
    if normalization is not None:
      damping_id = ("compute_pi_adjusted_damping",
                    "cov", self._input_factor.name,
                    "cov", self._output_factor.name,
                    "normalize_damping", damping, normalization, "power", 0.5)
    else:
      damping_id = ("compute_pi_adjusted_damping",
                    "cov", self._input_factor.name,
                    "cov", self._output_factor.name,
                    damping, "power", 0.5)

    self._input_damping_func = _package_func(lambda: compute_damping()[0],
                                             damping_id + ("ref", 0))
    self._output_damping_func = _package_func(lambda: compute_damping()[1],
                                              damping_id + ("ref", 1))

  def register_matpower(self, exp):
    """Registers the damped matrix power `exp` with both factors."""
    self._input_factor.register_matpower(exp, self._input_damping_func)
    self._output_factor.register_matpower(exp, self._output_damping_func)

  def register_cholesky(self):
    """Registers the damped Cholesky factor with both factors."""
    self._input_factor.register_cholesky(self._input_damping_func)
    self._output_factor.register_cholesky(self._output_damping_func)

  def register_cholesky_inverse(self):
    """Registers the damped inverse Cholesky factor with both factors."""
    self._input_factor.register_cholesky_inverse(self._input_damping_func)
    self._output_factor.register_cholesky_inverse(self._output_damping_func)

  @property
  def _renorm_coeff(self):
    """Kronecker factor multiplier coefficient.

    If this FisherBlock is represented as 'FB = c * kron(left, right)', then
    this is 'c'.

    Returns:
      A Python or NumPy scalar. The `multiply_*` methods pass it through
      `float()`, so it must be convertible with `float()`.
    """
    return 1.0

  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
                                extra_scale=1.0, transpose_left=False,
                                transpose_right=False):
    """Multiplies `vector` by the Kronecker product of the two factors.

    `vector` is reshaped to a 2-D matrix with `utils.layer_params_to_mat2d`.
    It is right-multiplied by `right_factor` (via `matmul_right`) and then
    left-multiplied by `left_factor` (via `matmul`), with the requested
    adjoints. The result is scaled by `extra_scale` and reshaped back to the
    layout of `vector`.

    Args:
      left_factor: Operator applied from the left (input-side factor).
      right_factor: Operator applied from the right (output-side factor).
      vector: Tensor or tuple of Tensors in layer-parameter layout.
      extra_scale: Python float. Scalar multiplier for the result.
      transpose_left: bool. Whether to use the adjoint of `left_factor`.
      transpose_right: bool. Whether to use the adjoint of `right_factor`.

    Returns:
      The product, in the same layout as `vector`.
    """
    reshaped_vector = utils.layer_params_to_mat2d(vector)
    reshaped_out = right_factor.matmul_right(reshaped_vector,
                                             adjoint=transpose_right)
    reshaped_out = left_factor.matmul(reshaped_out,
                                      adjoint=transpose_left)
    if extra_scale != 1.0:
      reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
    return utils.mat2d_to_layer_params(vector, reshaped_out)

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the damped Fisher block raised to power `exp`."""
    left_factor = self._input_factor.get_matpower(
        exp, self._input_damping_func)
    right_factor = self._output_factor.get_matpower(
        exp, self._output_damping_func)
    # (c * kron(A, B))^exp = c^exp * kron(A^exp, B^exp).
    extra_scale = float(self._renorm_coeff)**exp
    return self._multiply_factored_matrix(left_factor, right_factor, vector,
                                          extra_scale=extra_scale)

  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies `vector` by the (optionally transposed) Cholesky factor."""
    left_factor = self._input_factor.get_cholesky(self._input_damping_func)
    right_factor = self._output_factor.get_cholesky(self._output_damping_func)
    extra_scale = float(self._renorm_coeff)**0.5
    # The right factor acts from the right side of the reshaped matrix, so it
    # is applied with the opposite transposition to the left factor.
    return self._multiply_factored_matrix(left_factor, right_factor, vector,
                                          extra_scale=extra_scale,
                                          transpose_left=transpose,
                                          transpose_right=not transpose)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies `vector` by the (optionally transposed) inverse Cholesky."""
    left_factor = self._input_factor.get_cholesky_inverse(
        self._input_damping_func)
    right_factor = self._output_factor.get_cholesky_inverse(
        self._output_damping_func)
    extra_scale = float(self._renorm_coeff)**-0.5
    # See multiply_cholesky() for why the transpositions are mirrored.
    return self._multiply_factored_matrix(left_factor, right_factor, vector,
                                          extra_scale=extra_scale,
                                          transpose_left=transpose,
                                          transpose_right=not transpose)

  def full_fisher_block(self):
    """Explicitly constructs the full Fisher block.

    Used for testing purposes. (In general, the result may be very large.)
    The result is undamped: it uses the factors' covariances directly.

    Returns:
      The full Fisher block.
    """
    left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
    right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
    return self._renorm_coeff * utils.kronecker_product(left_factor,
                                                        right_factor)


class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
  """K-FAC FisherBlock for embedding layers.

  Works like FullyConnectedKFACBasicFB, except that the input factor is a
  diagonal approximation (EmbeddingInputKroneckerFactor). That approximation
  is exact whenever every example references exactly one embedding.

  Bias parameters are not supported.
  """

  def __init__(self, layer_collection, vocab_size):
    """Creates an EmbeddingKFACFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      vocab_size: int. Size of the vocabulary of this embedding layer.
    """
    self._vocab_size = vocab_size
    super(EmbeddingKFACFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the input and output Kronecker factors for this block.

    Args:
      grads_list: List of list of Tensors. grads_list[i][j] is the gradient of
        the loss with respect to 'outputs' from source 'i' and tower 'j'. Each
        Tensor has shape [tower_minibatch_size, output_size].
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    inputs, grads_list = self._process_data(grads_list)
    make_factor = self._layer_collection.make_or_get_factor

    input_args = (inputs, self._vocab_size)
    self._input_factor = make_factor(
        fisher_factors.EmbeddingInputKroneckerFactor, input_args)

    output_args = (grads_list,)
    self._output_factor = make_factor(
        fisher_factors.FullyConnectedKroneckerFactor, output_args)

    self._setup_damping(damping)


class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
  """K-FAC FisherBlock for fully-connected (dense) layers.

  Uses the Kronecker-factored approximation introduced in the original K-FAC
  paper (https://arxiv.org/abs/1503.05671).
  """

  def __init__(self, layer_collection, has_bias=False):
    """Creates a FullyConnectedKFACBasicFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      has_bias: Whether the component Kronecker factors have an additive bias.
          (Default: False)
    """
    self._has_bias = has_bias
    super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the input and output Kronecker factors for this block.

    Args:
      grads_list: List of list of Tensors. grads_list[i][j] is the gradient of
        the loss with respect to 'outputs' from source 'i' and tower 'j'. Each
        Tensor has shape [tower_minibatch_size, output_size].
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    inputs, grads_list = self._process_data(grads_list)
    make_factor = self._layer_collection.make_or_get_factor
    kronecker_factor = fisher_factors.FullyConnectedKroneckerFactor

    # The input factor receives the inputs wrapped in a 1-tuple plus the bias
    # flag; the output factor receives the processed gradients.
    self._input_factor = make_factor(kronecker_factor,
                                     ((inputs,), self._has_bias))
    self._output_factor = make_factor(kronecker_factor, (grads_list,))

    self._setup_damping(damping)


class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
  r"""FisherBlock for convolutional layers using the basic KFC approx.

  Estimates the Fisher Information matrix's block for a convolutional
  layer.

  Consider a convolutional layer in this model with (unshared) filter matrix
  'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
  this FisherBlock estimates,

    $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
                                  E[flat(ds) flat(ds)^T])$$

  where

    $$ds = (d / ds) log p(y | x, w)$$
    #locations = number of (x, y) locations where 'w' is applied.

  and the expectation is taken over all examples and locations and flat()
  concatenates an array's leading dimensions.

  See equation 23 in https://arxiv.org/abs/1602.01407 for details.
  """

  def __init__(self,
               layer_collection,
               params,
               padding,
               strides=None,
               dilation_rate=None,
               data_format=None,
               extract_patches_fn=None):
    """Creates a ConvKFCBasicFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [..spatial_filter_shape..,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels].
      padding: str. Padding method.
      strides: List of ints or None. Contains [..spatial_filter_strides..] if
        'extract_patches_fn' is compatible with tf.nn.convolution(), else
        [1, ..spatial_filter_strides.., 1].
      dilation_rate: List of ints or None. Rate for dilation along each spatial
        dimension if 'extract_patches_fn' is compatible with
        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
      data_format: str or None. Format of input data.
      extract_patches_fn: str or None. Name of function that extracts image
        patches. One of "extract_convolution_patches", "extract_image_patches",
        "extract_pointwise_conv2d_patches".
    """
    self._padding = padding
    self._strides = maybe_tuple(strides)
    self._dilation_rate = maybe_tuple(dilation_rate)
    self._data_format = data_format
    self._extract_patches_fn = extract_patches_fn
    # A tuple/list of params means (kernel, bias).
    self._has_bias = isinstance(params, (tuple, list))

    fltr = params[0] if self._has_bias else params
    self._filter_shape = tuple(fltr.shape.as_list())

    super(ConvKFCBasicFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Instantiates the input and output Kronecker factors for this block.

    Also records the number of spatial locations the filter is applied at.
    That count normalizes the damping and is returned as `_renorm_coeff`.

    Args:
      grads_list: Gradients of the loss with respect to the layer outputs, in
        the format accepted by `self._process_data()`.
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    inputs, grads_list = self._process_data(grads_list)

    # Infer number of locations upon which convolution is applied.
    self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
                                             self._strides)

    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvInputKroneckerFactor,
        (inputs, self._filter_shape, self._padding, self._strides,
         self._dilation_rate, self._data_format, self._extract_patches_fn,
         self._has_bias))
    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))

    self._setup_damping(damping, normalization=self._num_locations)

  @property
  def _renorm_coeff(self):
    """Number of conv locations, as computed in `instantiate_factors()`."""
    return self._num_locations


class DepthwiseConvDiagonalFB(ConvDiagonalFB):
  """FisherBlock for depthwise_conv2d().

  Equivalent to ConvDiagonalFB applied to each input channel in isolation.
  Vectors are converted to the equivalent conv2d filter layout (see
  depthwise_conv2d_filter_to_conv2d_filter()) before being multiplied.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               rate=None,
               data_format=None):
    """Creates a DepthwiseConvDiagonalFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: Tensor of shape [filter_height, filter_width, in_channels,
        channel_multiplier].
      strides: List of 4 ints. Strides along all dimensions.
      padding: str. Padding method.
      rate: List of 2 ints or None. Rate for dilation along the two spatial
        dimensions. Expanded to [1, rate[0], rate[1], 1] before being passed
        to ConvDiagonalFB as 'dilations'.
      data_format: str or None. Format of input data.

    Raises:
      NotImplementedError: If parameters contains bias.
      ValueError: If filter is not 4-D.
      ValueError: If strides is not length-4.
      ValueError: If rate is not length-2.
      ValueError: If channels are not last dimension.
    """
    if isinstance(params, (tuple, list)):
      raise NotImplementedError("Bias not yet supported.")

    if params.shape.ndims != 4:
      raise ValueError("Filter must be 4-D.")

    if len(strides) != 4:
      raise ValueError("strides must account for 4 dimensions.")

    if rate is not None:
      if len(rate) != 2:
        raise ValueError("rate must only account for spatial dimensions.")
      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.

    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("data_format must be channels-last.")

    super(DepthwiseConvDiagonalFB, self).__init__(
        layer_collection=layer_collection,
        params=params,
        strides=strides,
        padding=padding,
        dilations=rate,
        data_format=data_format)

    # This is a hack to overwrite the same setting in ConvDiagonalFB.__init__():
    # the block is modeled with the shape of the equivalent conv2d filter,
    # whose out_channels is in_channels * channel_multiplier.
    filter_height, filter_width, in_channels, channel_multiplier = (
        params.shape.as_list())
    self._filter_shape = (filter_height, filter_width, in_channels,
                          in_channels * channel_multiplier)

  def _multiply_matrix(self, matrix, vector):
    """Multiplies a depthwise-layout `vector` via the conv2d-layout parent.

    `vector` is converted to the equivalent conv2d filter and passed to the
    superclass's `_multiply_matrix`. The result is converted back to the
    depthwise layout, keeping only the diagonal blocks.
    """
    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
    conv2d_result = super(
        DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)


class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
  """FisherBlock for depthwise_conv2d().

  Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
  Vectors are converted to the equivalent conv2d filter layout (see
  depthwise_conv2d_filter_to_conv2d_filter()) before being multiplied.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               rate=None,
               data_format=None):
    """Creates a DepthwiseConvKFCBasicFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: Tensor of shape [filter_height, filter_width, in_channels,
        channel_multiplier].
      strides: List of 4 ints. Strides along all dimensions.
      padding: str. Padding method.
      rate: List of 2 ints or None. Rate for dilation along the two spatial
        dimensions. Expanded to [1, rate[0], rate[1], 1] before being passed
        to ConvKFCBasicFB as 'dilation_rate'.
      data_format: str or None. Format of input data.

    Raises:
      NotImplementedError: If parameters contains bias.
      ValueError: If filter is not 4-D.
      ValueError: If strides is not length-4.
      ValueError: If rate is not length-2.
      ValueError: If channels are not last dimension.
    """
    if isinstance(params, (tuple, list)):
      raise NotImplementedError("Bias not yet supported.")

    if params.shape.ndims != 4:
      raise ValueError("Filter must be 4-D.")

    if len(strides) != 4:
      raise ValueError("strides must account for 4 dimensions.")

    if rate is not None:
      if len(rate) != 2:
        raise ValueError("rate must only account for spatial dimensions.")
      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.

    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("data_format must be channels-last.")

    # 4-element strides/rate match the "extract_image_patches" convention
    # documented on ConvKFCBasicFB.
    super(DepthwiseConvKFCBasicFB, self).__init__(
        layer_collection=layer_collection,
        params=params,
        padding=padding,
        strides=strides,
        dilation_rate=rate,
        data_format=data_format,
        extract_patches_fn="extract_image_patches")

    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__():
    # the block is modeled with the shape of the equivalent conv2d filter,
    # whose out_channels is in_channels * channel_multiplier.
    filter_height, filter_width, in_channels, channel_multiplier = (
        params.shape.as_list())
    self._filter_shape = (filter_height, filter_width, in_channels,
                          in_channels * channel_multiplier)

  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
                                extra_scale=1.0, transpose_left=False,
                                transpose_right=False):
    """Applies the Kronecker-factored product to a depthwise-layout `vector`.

    `vector` is converted to the equivalent conv2d filter and multiplied by
    KroneckerProductFB's implementation. The result is converted back to the
    depthwise layout, keeping only the diagonal blocks.
    """
    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
    conv2d_result = super(
        DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
            left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
            transpose_left=transpose_left, transpose_right=transpose_right)
    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)


def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
  """Converts a convolution filter for use with conv2d.

  Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
  compatible with tf.nn.conv2d(). The result is block-sparse:
  - output channels [i * channel_multiplier, (i + 1) * channel_multiplier)
    read only from input channel i;
  - every other entry is zero.

  Args:
    filter: Tensor of shape [height, width, in_channels, channel_multiplier].
    name: None or str. Name of Op.

  Returns:
    Tensor of shape [height, width, in_channels, out_channels], where
    out_channels = in_channels * channel_multiplier, with the same (base)
    dtype as 'filter'.
  """
  with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
                      [filter]):
    filter = ops.convert_to_tensor(filter)
    filter_height, filter_width, in_channels, channel_multiplier = (
        filter.shape.as_list())
    # zeros() defaults to float32; the padding must match the filter's dtype or
    # the concat below fails for float16/float64 filters.
    dtype = filter.dtype.base_dtype

    def _zero_channels(num_channels):
      # Zero block spanning `num_channels` input channels.
      return array_ops.zeros(
          [filter_height, filter_width, num_channels, channel_multiplier],
          dtype=dtype)

    results = []
    for i in range(in_channels):
      # Slice out one in_channel's filter. Insert zeros around it to force it
      # to affect that channel and that channel alone.
      elements = []
      if i > 0:
        elements.append(_zero_channels(i))
      elements.append(filter[:, :, i:(i + 1), :])
      if i + 1 < in_channels:
        elements.append(_zero_channels(in_channels - (i + 1)))

      # Concat along in_channel.
      results.append(
          array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))

    # Concat along out_channel.
    return array_ops.concat(results, axis=-1, name="out_channel")


def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
  """Converts a convolution filter for use with depthwise_conv2d.

  Transforms a filter for use with tf.nn.conv2d() to one that's compatible
  with tf.nn.depthwise_conv2d():
  - The output-channel axis is viewed as [in_channels, channel_multiplier].
  - Only the diagonal blocks, which pair input channel i with output group i,
    are kept.
  - All off-diagonal entries are discarded.

  Args:
    filter: Tensor of shape [height, width, in_channels, out_channels].
    name: None or str. Name of Op.

  Returns:
    Tensor of shape [height, width, in_channels, channel_multiplier].

  Raises:
    ValueError: if out_channels is not evenly divisible by in_channels.
  """
  with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
                      [filter]):
    filter = ops.convert_to_tensor(filter)
    height, width, in_channels, out_channels = filter.shape.as_list()

    channel_multiplier, remainder = divmod(out_channels, in_channels)
    if remainder:
      raise ValueError("out_channels must be evenly divisible by in_channels.")

    # Split out_channels into (input-channel group, multiplier) so diagonal
    # block i can be indexed as [:, :, i, i, :].
    grouped = array_ops.reshape(
        filter, [height, width, in_channels, in_channels, channel_multiplier])
    diagonal_blocks = [
        array_ops.reshape(grouped[:, :, i, i, :],
                          [height, width, 1, channel_multiplier])
        for i in range(in_channels)
    ]

    # Stack the per-channel blocks back along the in_channel axis.
    return array_ops.concat(diagonal_blocks, axis=-2, name="in_channels")


def maybe_tuple(obj):
  """Returns `obj` as a tuple if it is a list, otherwise returns it unchanged.

  Useful for making list-valued settings (e.g. strides) hashable.
  """
  return tuple(obj) if isinstance(obj, list) else obj


def num_conv_locations(input_shape, strides):
  """Returns the number of spatial locations a 2D Conv kernel is applied to.

  Computed as the product of the spatial input dimensions (all but the first
  and last entries of `input_shape`) floor-divided by the product of
  `strides`. Spatial-only strides and strides padded with leading/trailing 1s
  give the same result.

  Args:
    input_shape: List of ints representing shape of inputs to
      tf.nn.convolution().
    strides: List of ints representing strides along spatial dimensions as
      passed in to tf.nn.convolution(), or None for unit strides.

  Returns:
    A scalar |T| denoting the number of spatial locations for the Conv layer.
  """
  locations = np.prod(input_shape[1:-1])
  divisor = 1 if strides is None else np.prod(strides)
  return locations // divisor


class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
  """Adds methods for multi-use/time-step case to InputOutputMultiTower.

  Overrides `_process_data()` so a layer used several times can supply its
  inputs and output gradients in either of two ways:
  - nested by tower and use;
  - with the uses folded into the batch dimension.
  The number of uses is stored in `self._num_uses`. It is inferred from the
  nested-list format when it was not given explicitly.
  """

  def __init__(self, num_uses=None, *args, **kwargs):
    """Stores `num_uses` and forwards other arguments to the superclass.

    Args:
      num_uses: int or None. Number of uses of the layer. When None, it must be
        inferable from list-formatted data in `_process_data()`.
      *args: Forwarded positionally to the next constructor in the MRO.
      **kwargs: Forwarded as keywords to the next constructor in the MRO.
    """
    # NOTE(review): `num_uses` is the first positional parameter, so a
    # positionally passed `layer_collection` would be bound to `num_uses`.
    # The subclasses in this file pass both arguments by keyword.
    self._num_uses = num_uses
    super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)

  def _process_data(self, grads_list):
    """Process temporal/multi-use data into the format used by the factors.

    This function takes inputs and grads_lists data and processes it into
    one of the formats expected by the FisherFactor classes (depending on
    the value of the global configuration variable TOWER_STRATEGY).

    It accepts the data in one of two initial formats. The first possible
    format is where self._inputs is a list of list of Tensors. The first index
    is tower, the second is use/time-step. grads_list, meanwhile, is a list
    over sources of such lists of lists.

    The second possible data format is where self._inputs is a Tensor with
    uses/time-steps folded into the batch dimension.  i.e. it is a Tensor
    of shape [num_uses * size_batch, ...] which represents a reshape of a
    Tensor of shape [num_uses, size_batch, ...].  And similarly grads_list is
    a list over sources of such Tensors.

    There are two possible formats which inputs and grads_list are transformed
    into.

    If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
    a single tensor (represented as a PartitionedTensor object) with all of
    the data from the towers, as well as the uses/time-steps, concatenated
    together. In this tensor the leading dimension is the batch and
    use/time-step dimensions folded together (with 'use' being the major of
    these two, so that the tensors can be thought of as reshapes of ones of
    shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
    tuple over sources of such tensors.

    If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
    tensors over towers. Each of these tensors has a similar format to
    the tensor produced by the "concat" option, except that each contains
    only the data from a single tower.  grads_list is similarly formatted
    into a tuple over sources of such tuples.

    Side effect: sets `self._num_uses` when it can be inferred from the
    list-formatted data.

    Args:
      grads_list: grads_list in its initial format (see above).

    Returns:
      inputs: self._inputs transformed into the appropriate format (see
        above).
      grads_list: grads_list transformed into the appropriate format (see
        above).

    Raises:
      ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
      ValueError: If the given/initial format of self._inputs and grads_list
        isn't recognized, or doesn't agree with self._num_uses.
    """

    inputs = self._inputs

    # List format: inputs[tower][use]. Infer/validate num_uses from it.
    if isinstance(inputs[0], (list, tuple)):
      num_uses = len(inputs[0])
      if self._num_uses is not None and self._num_uses != num_uses:
        raise ValueError("num_uses argument doesn't match length of inputs.")
      else:
        self._num_uses = num_uses

      # Check that all mini-batches/towers have the same number of uses
      if not all(len(input_) == num_uses for input_ in inputs):
        raise ValueError("Length of inputs argument is inconsistent across "
                         "towers.")

      if fisher_factors.TOWER_STRATEGY == "concat":
        # Reverse the tower and use/time-step indices, so that use is now first,
        # and towers is second
        inputs = tuple(zip(*inputs))

        # Flatten the two dimensions
        inputs = nest.flatten(inputs)

        # Merge everything together into a PartitionedTensor. We package it in
        # a singleton tuple since the factors will expect a list over towers
        inputs = (utils.PartitionedTensor(inputs),)

      elif fisher_factors.TOWER_STRATEGY == "separate":
        # Merge together the uses/time-step dimension into PartitionedTensors,
        # but keep the leading dimension (towers) intact for the factors to
        # process individually.
        inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)

      else:
        raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                         "'concat' or 'separate'.")
    else:
      # Folded format: one Tensor per tower, passed through as a tuple.
      inputs = tuple(inputs)

    # Now we perform the analogous processing for grads_list
    # List format here is grads_list[source][tower][use].
    if isinstance(grads_list[0][0], (list, tuple)):
      num_uses = len(grads_list[0][0])
      if self._num_uses is not None and self._num_uses != num_uses:
        raise ValueError("num_uses argument doesn't match length of outputs, "
                         "or length of outputs is inconsistent with length of "
                         "inputs.")
      else:
        self._num_uses = num_uses

      if not all(len(grad) == num_uses for grads in grads_list
                 for grad in grads):
        raise ValueError("Length of outputs argument is inconsistent across "
                         "towers.")

      if fisher_factors.TOWER_STRATEGY == "concat":
        # Reverse the tower and use/time-step indices, so that use is now first,
        # and towers is second
        grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)

        # Flatten the two dimensions, leaving the leading dimension (source)
        # intact
        grads_list = tuple(nest.flatten(grads) for grads in grads_list)

        # Merge inner dimensions together into PartitionedTensors. We package
        # them in a singleton tuple since the factors will expect a list over
        # towers
        grads_list = tuple((utils.PartitionedTensor(grads),)
                           for grads in grads_list)

      elif fisher_factors.TOWER_STRATEGY == "separate":
        # Merge together the uses/time-step dimension into PartitionedTensors,
        # but keep the leading dimension (towers) intact for the factors to
        # process individually.
        grads_list = tuple(tuple(utils.PartitionedTensor(grad)
                                 for grad in grads)
                           for grads in grads_list)

      else:
        raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                         "'concat' or 'separate'.")
    else:
      # Folded format: convert the nested sequences to tuples unchanged.
      grads_list = tuple(tuple(grads) for grads in grads_list)

    # With both arguments in folded format, num_uses cannot be inferred and
    # must have been supplied to __init__.
    if self._num_uses is None:
      raise ValueError("You must supply a value for the num_uses argument if "
                       "the number of uses cannot be inferred from inputs or "
                       "outputs arguments (e.g. if they are both given in the "
                       "single Tensor format, instead of as lists of Tensors.")

    return inputs, grads_list


class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
                                 KroneckerProductFB):
  """FisherBlock for fully-connected layers whose parameters are reused.

  Implements the "independence across time" approximation described in:
    https://openreview.net/pdf?id=HyMTkQZAb
  """

  def __init__(self, layer_collection, has_bias=False, num_uses=None):
    """Creates a FullyConnectedMultiIndepFB block.

    Args:
      layer_collection: LayerCollection instance.
      has_bias: bool. If True, estimates Fisher with respect to a bias
        parameter as well as the layer's parameters.
      num_uses: int or None. Number of uses of the layer in the model's graph.
        Only required if the data is formatted with uses/time folded into the
        batch dimension (instead of uses/time being a list dimension).
        (Default: None)
    """
    self._has_bias = has_bias
    super(FullyConnectedMultiIndepFB, self).__init__(
        layer_collection=layer_collection, num_uses=num_uses)

  def instantiate_factors(self, grads_list, damping):
    """Creates the multi-use input and output Kronecker factors.

    Args:
      grads_list: Gradients of the loss with respect to the layer outputs, in
        a format accepted by `self._process_data()`.
      damping: 0-D Tensor or float. Damping for this block, normalized by the
        number of uses.
    """
    # _process_data() may set self._num_uses, so it has to run first.
    inputs, grads_list = self._process_data(grads_list)
    make_factor = self._layer_collection.make_or_get_factor
    multi_factor = fisher_factors.FullyConnectedMultiKF

    self._input_factor = make_factor(
        multi_factor, ((inputs,), self._num_uses, self._has_bias))
    self._output_factor = make_factor(
        multi_factor, (grads_list, self._num_uses))

    self._setup_damping(damping, normalization=self._num_uses)

  @property
  def _renorm_coeff(self):
    """The block's Kronecker product is scaled by the number of uses."""
    return float(self._num_uses)


class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
                               KroneckerProductFB):
  """FisherBlock for 2D convolutional layers using the basic KFC approx.

  Similar to ConvKFCBasicFB except that this version supports multiple
  uses/time-steps via a standard independence approximation.  Similar to the
  "independence across time" used in FullyConnectedMultiIndepFB but generalized
  in the obvious way to conv layers.
  """

  def __init__(self,
               layer_collection,
               params,
               padding,
               strides=None,
               dilation_rate=None,
               data_format=None,
               extract_patches_fn=None,
               num_uses=None):
    """Creates a ConvKFCBasicMultiIndepFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
          Fisher information matrix to which this FisherBlock belongs.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [..spatial_filter_shape..,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels].
      padding: str. Padding method.
      strides: List of ints or None. Contains [..spatial_filter_strides..] if
        'extract_patches_fn' is compatible with tf.nn.convolution(), else
        [1, ..spatial_filter_strides.., 1].
      dilation_rate: List of ints or None. Rate for dilation along each spatial
        dimension if 'extract_patches_fn' is compatible with
        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
      data_format: str or None. Format of input data.
      extract_patches_fn: str or None. Name of function that extracts image
        patches. One of "extract_convolution_patches", "extract_image_patches",
        "extract_pointwise_conv2d_patches".
      num_uses: int or None. Number of uses of the layer in the model's graph.
        Only required if the data is formatted with uses/time folded into the
        batch dimension (instead of uses/time being a list dimension).
        (Default: None)
    """
    self._padding = padding
    self._strides = maybe_tuple(strides)
    self._dilation_rate = maybe_tuple(dilation_rate)
    self._data_format = data_format
    self._extract_patches_fn = extract_patches_fn
    # A tuple/list of params means (kernel, bias).
    self._has_bias = isinstance(params, (tuple, list))

    fltr = params[0] if self._has_bias else params
    self._filter_shape = tuple(fltr.shape.as_list())

    super(ConvKFCBasicMultiIndepFB, self).__init__(
        layer_collection=layer_collection,
        num_uses=num_uses)

  def instantiate_factors(self, grads_list, damping):
    """Instantiates the input and output Kronecker factors for this block.

    The damping is normalized by, and the Kronecker product is scaled by,
    the number of conv locations times the number of uses.

    Args:
      grads_list: Gradients of the loss with respect to the layer outputs, in
        a format accepted by `self._process_data()`.
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    inputs, grads_list = self._process_data(grads_list)

    # Infer number of locations upon which convolution is applied.
    self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
                                             self._strides)

    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvInputKroneckerFactor,
        (inputs, self._filter_shape, self._padding, self._strides,
         self._dilation_rate, self._data_format, self._extract_patches_fn,
         self._has_bias))
    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))

    self._setup_damping(
        damping, normalization=(self._num_locations * self._num_uses))

  @property
  def _renorm_coeff(self):
    """Number of conv locations times the number of uses."""
    return self._num_locations * self._num_uses


class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
                                KroneckerProductFB):
  """Multi-use K-FAC FisherBlock for an embedding layer (no bias).

  This is the multi-use counterpart of EmbeddingKFACFB. The same embedding
  parameters may appear several times within a single model graph, e.g. once
  per time step of an RNN, although the uses need not be time steps.

  Bias parameters are not supported.
  """

  def __init__(self, layer_collection, vocab_size, num_uses=None):
    """Creates an EmbeddingKFACMultiIndepFB block.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
        Fisher information matrix to which this FisherBlock belongs.
      vocab_size: int. Size of vocabulary for this embedding layer.
      num_uses: int or None. Number of uses of the layer in the model's graph.
        Only needed when time is folded into the batch dimension of the data
        instead of being given as a separate list dimension. (Default: None)
    """
    # Set before the base-class constructor runs.
    self._vocab_size = vocab_size

    super(EmbeddingKFACMultiIndepFB, self).__init__(
        layer_collection=layer_collection, num_uses=num_uses)

  def instantiate_factors(self, grads_list, damping):
    """Instantiates the Kronecker factors for this FisherBlock.

    Args:
      grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
        gradient of the loss with respect to 'outputs' from source 'i',
        tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
        [tower_minibatch_size, output_size].
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    processed_inputs, processed_grads = self._process_data(grads_list)

    # Input side: embedding-specific factor, parameterized by vocab size.
    input_factor_args = (processed_inputs, self._vocab_size)
    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.EmbeddingInputKroneckerFactor, input_factor_args)

    # Output side: multi-use fully-connected factor over the gradients.
    output_factor_args = (processed_grads, self._num_uses)
    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedMultiKF, output_factor_args)

    self._setup_damping(damping, normalization=self._num_uses)

  @property
  def _renorm_coeff(self):
    """The number of uses of the layer, as a Python float."""
    num_uses = self._num_uses
    return float(num_uses)


class SeriesFBApproximation(enum.IntEnum):
  """Simplifying assumption used by FullyConnectedSeriesFB.

  Selected through the `option` argument of FullyConnectedSeriesFB.__init__.
  """
  # Approximates the cross-covariance over time as a symmetric matrix.
  option1 = 1
  # Assumes that training sequences are infinitely long.
  option2 = 2


class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
                             KroneckerProductFB):
  """FisherBlock for fully-connected layers that share parameters across time.

  This class implements the "Option 1" and "Option 2" approximations from the
  following paper:
    https://openreview.net/pdf?id=HyMTkQZAb

  See the end of the appendix of the paper for pseudo-code of the
  algorithm implemented by multiply_matpower here.  Note that we are
  using pre-computed versions of certain matrix-matrix products to speed
  things up.  This is explicitly explained wherever it is done.

  Only inverse multiplication (exp == -1) is supported, and Cholesky-based
  multiplications are not supported at all.
  """

  def __init__(self,
               layer_collection,
               has_bias=False,
               num_uses=None,
               option=SeriesFBApproximation.option2):
    """Constructs a new `FullyConnectedSeriesFB`.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
        Fisher information matrix to which this FisherBlock belongs.
      has_bias: Whether the layer includes a bias parameter.
      num_uses: int or None. Number of time-steps over which the layer
        is used. Only required if the data is formatted with time folded into
        the batch dimension (instead of time being a list dimension).
        (Default: None)
      option: A `SeriesFBApproximation` specifying the simplifying assumption
        to be used in this block. `option1` approximates the cross-covariance
        over time as a symmetric matrix, while `option2` makes
        the assumption that training sequences are infinitely long. See section
        3.5 of the paper for more details.
    """

    self._has_bias = has_bias
    self._option = option

    super(FullyConnectedSeriesFB, self).__init__(
        layer_collection=layer_collection,
        num_uses=num_uses)

  @property
  def _num_timesteps(self):
    """Number of time-steps; an alias for the number of uses of the layer."""
    return self._num_uses

  @property
  def _renorm_coeff(self):
    # This should no longer be used since the multiply_X functions from the
    # base class have been overridden. Raise explicitly instead of using
    # `assert False`, which is stripped when Python runs with -O and would
    # then silently return None.
    raise NotImplementedError("FullyConnectedSeriesFB does not use "
                              "_renorm_coeff; its multiply functions are "
                              "overridden.")

  def instantiate_factors(self, grads_list, damping):
    """Creates the input and output Kronecker factors for this block.

    Both factors are `fisher_factors.FullyConnectedMultiKF` instances, and
    `register_cov_dt1()` is called on each. Presumably this requests the
    time-lagged covariance from which the option1/option2 quantities are
    built (TODO confirm in fisher_factors).

    Args:
      grads_list: Gradient data for this block. `self._process_data` returns
        the processed inputs and gradients from it.
      damping: Damping value passed to `self._setup_damping` with
        normalization `self._num_uses`.
    """
    inputs, grads_list = self._process_data(grads_list)

    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedMultiKF,
        ((inputs,), self._num_uses, self._has_bias))
    self._input_factor.register_cov_dt1()

    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
    self._output_factor.register_cov_dt1()

    self._setup_damping(damping, normalization=self._num_uses)

  def _raise_if_not_inverse(self, exp):
    """Raises NotImplementedError unless `exp` requests the inverse (-1)."""
    if exp != -1:
      raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
                                "multiplications.")

  def register_matpower(self, exp):
    """Registers the factor quantities needed to multiply by the inverse.

    Args:
      exp: Matrix power. Only -1 is supported.

    Raises:
      NotImplementedError: If `exp` is not -1.
      ValueError: If `self._option` is not a recognized
        `SeriesFBApproximation`.
    """
    self._raise_if_not_inverse(exp)

    if self._option == SeriesFBApproximation.option1:
      self._input_factor.register_option1quants(self._input_damping_func)
      self._output_factor.register_option1quants(self._output_damping_func)
    elif self._option == SeriesFBApproximation.option2:
      self._input_factor.register_option2quants(self._input_damping_func)
      self._output_factor.register_option2quants(self._output_damping_func)
    else:
      raise ValueError(
          "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
              self._option))

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the inverse of this block's Fisher approximation.

    Args:
      vector: Layer parameters (converted with `utils.layer_params_to_mat2d`).
      exp: Matrix power. Only -1 is supported.

    Returns:
      The product, in the same layer-parameter format as `vector` (converted
      back with `utils.mat2d_to_layer_params`).

    Raises:
      NotImplementedError: If `exp` is not -1.
      ValueError: If `self._option` is not a recognized
        `SeriesFBApproximation`. Without this check an unknown option would
        silently return `vector` unchanged.
    """
    self._raise_if_not_inverse(exp)

    # pylint: disable=invalid-name

    Z = utils.layer_params_to_mat2d(vector)

    # Derivations were done for "batch_dim==1" case so we need to convert to
    # that orientation:
    Z = array_ops.transpose(Z)

    if self._option == SeriesFBApproximation.option1:

      # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
      L_A, psi_A = self._input_factor.get_option1quants(
          self._input_damping_func)
      L_G, psi_G = self._output_factor.get_option1quants(
          self._output_damping_func)

      def gamma(x):
        # We are assuming that each case has the same number of time-steps.
        # If this stops being the case one shouldn't simply replace this T
        # with its average value.  Instead, one needs to go back to the
        # definition of the gamma function from the paper.
        T = self._num_timesteps
        return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

      # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
      # Even though Y is Z-independent we are recomputing it from the psi's
      # each time since Y depends on both A and G quantities, and it is
      # relatively cheap to compute.
      Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)

      # \\(Z = L_G^T * Z * L_A\\)
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(Z = U_G^T * Z * U_A\\)
      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)

      # \\(Z = Z .* Y\\)
      Z *= Y

      # \\(Z = L_G * Z * L_A^T\\)
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # \\(Z = U_G * Z * U_A^T\\)
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))

    elif self._option == SeriesFBApproximation.option2:

      # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
      # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
      P_A, K_A, mu_A = self._input_factor.get_option2quants(
          self._input_damping_func)
      P_G, K_G, mu_G = self._output_factor.get_option2quants(
          self._output_damping_func)

      # Our approach differs superficially from the pseudo-code in the paper
      # in order to reduce the total number of matrix-matrix multiplies.
      # In particular, the first three computations in the pseudo code are
      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
      # \\(Z = E_G^T * Z * E_A\\)
      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
      # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
      # the entire computation can be written as
      # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(    - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
      # \\(  = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
      # \\(    - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
      # \\(  = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
      # \\(    -  E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
      # \\(  = K_G^T * Z * K_A  -  K_G^T * P_G * Z * P_A^T * K_A\\)
      # This final expression is computed by the following two lines:
      # \\(Z = Z - P_G * Z * P_A^T\\)
      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
      # \\(Z = K_G^T * Z * K_A\\)
      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)

      # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
      # Be careful with the outer product.  We don't want to accidentally
      # make it an inner-product instead.
      tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
      # Prevent some numerical issues by setting any 0.0 eigs to 1.0
      tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
      Z /= tmp

      # We now perform the transpose/reverse version of the operations
      # derived above, whose derivation from the original pseudo-code is
      # analogous.
      # \\(Z = K_G * Z * K_A^T\\)
      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))

      # \\(Z = Z - P_G^T * Z * P_A\\)
      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)

      # \\(Z = normalize (1/E[T]) * Z\\)
      # Note that this normalization is done because we compute the statistics
      # by averaging, not summing, over time. (And the gradient is presumably
      # summed over time, not averaged, and thus their scales are different.)
      Z /= math_ops.cast(self._num_timesteps, Z.dtype)

    else:
      raise ValueError(
          "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
              self._option))

    # Convert back to the "batch_dim==0" orientation.
    Z = array_ops.transpose(Z)

    # pylint: enable=invalid-name

    return utils.mat2d_to_layer_params(vector, Z)

  def multiply_cholesky(self, vector):
    raise NotImplementedError("FullyConnectedSeriesFB does not support "
                              "Cholesky computations.")

  def multiply_cholesky_inverse(self, vector):
    raise NotImplementedError("FullyConnectedSeriesFB does not support "
                              "Cholesky computations.")