diff options
author | James Martens <jamesmartens@google.com> | 2018-04-26 04:37:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-26 04:40:47 -0700 |
commit | 481f229881c915fec0822f68c6ce0ebbb9983da0 (patch) | |
tree | e92807d0dfff38c86aafb6a83649137911bef2a0 /tensorflow/contrib/kfac | |
parent | 8148895adc1cf35112fb7197a798bc825a61e4f6 (diff) |
- Adding support for Cholesky (inverse) factor multiplications.
- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request: b/73356352
- Fixed some problems with FisherEstimator construction
- More careful casting of damping constants before they are used
PiperOrigin-RevId: 194379298
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py | 106 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py | 69 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator_lib.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 271 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/fisher_factors.py | 322 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/linear_operator.py | 95 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/placement.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/utils.py | 16 |
11 files changed, 632 insertions, 277 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 2477d2bfc1..c2436affe2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -58,6 +58,7 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 6eda6c31e3..566d393f45 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -46,8 +47,9 @@ class UtilsTest(test.TestCase): def testComputePiTracenorm(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) # pi is the sqrt of the left trace norm divided by the right trace norm pi = fb.compute_pi_tracenorm(left_factor, right_factor) @@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase): full = sess.run(block.full_fisher_block()) explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 432b67e569..9153ddf09c 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -70,35 +70,44 @@ class FisherFactorTestingDummy(ff.FisherFactor): def get_cov(self): return NotImplementedError - def left_multiply(self, x, damping): + def instantiate_inv_variables(self): return NotImplementedError - def right_multiply(self, x, damping): - return NotImplementedError + def _num_towers(self): + raise NotImplementedError - def left_multiply_matpower(self, x, exp, damping): - return NotImplementedError + def _get_data_device(self): + raise NotImplementedError - def right_multiply_matpower(self, x, exp, damping): - return NotImplementedError + def register_matpower(self, exp, damping_func): + raise NotImplementedError - def instantiate_inv_variables(self): - return NotImplementedError + def register_cholesky(self, damping_func): + raise NotImplementedError - def _num_towers(self): + def register_cholesky_inverse(self, damping_func): raise NotImplementedError - def _get_data_device(self): + def get_matpower(self, exp, damping_func): raise NotImplementedError + def get_cholesky(self, damping_func): + raise NotImplementedError -class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): - """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. + def get_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_cov_as_linear_operator(self): + raise NotImplementedError + + +class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): + """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. """ def __init__(self, shape): self._shape = shape - super(InverseProvidingFactorTestingDummy, self).__init__() + super(DenseSquareMatrixFactorTestingDummy, self).__init__() @property def _var_scope(self): @@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase): self.assertEqual(0, len(factor.make_inverse_update_ops())) -class InverseProvidingFactorTest(test.TestCase): +class DenseSquareMatrixFactorTest(test.TestCase): def testRegisterDampedInverse(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [2, 2] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' damping_funcs = [make_damping_func(0.1), @@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase): factor.instantiate_inv_variables() - inv = factor.get_inverse(damping_funcs[0]) - self.assertEqual(inv, factor.get_inverse(damping_funcs[1])) - self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2])) - self.assertEqual(factor.get_inverse(damping_funcs[2]), - factor.get_inverse(damping_funcs[3])) + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]), - set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [3, 3] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' # TODO(b/74201126): Change to using the same func for both once @@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase): factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(-0.5, damping_func_1) - matpower2 = factor.get_matpower(2, damping_func_2) - self.assertEqual(set([matpower1, matpower2]), set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[1., 2.], [3., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_funcs = [] @@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(ops) for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor.get_inverse(damping_funcs[i]))) + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[6., 2.], [2., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 @@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor.get_matpower(exp, damping_func)) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_func = make_damping_func(0) @@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor.get_inverse(damping_func)) + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor.get_inverse(damping_func)) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase): tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase): tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.shape.as_list(), [vocab_size]) def testCovarianceUpdateOp(self): @@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase): self.kernel_height * self.kernel_width * self.in_channels, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(): @@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase): self.kernel_height * self.kernel_width * self.in_channels + 1, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure update op doesn't crash. cov_update_op = factor.make_covariance_update_op(0.0) @@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): # Ensure shape of covariance matches input size of filter. input_size = in_channels * (width**3) self.assertEqual([input_size, input_size], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-8, as the filter will be applied at each corner of # the 4-D cube. @@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): # Ensure shape of covariance matches input size of filter. self.assertEqual([in_channels, in_channels], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-9, as the filter will be applied at each location. self.assertMatrixRank(9, cov) @@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be the sum of 3 * 2 = 6 outer products. self.assertMatrixRank(6, cov) @@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank = in_channels, as only the center of the filter # receives non-zero input for each input channel. diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index cb0917bb85..3c01eb65e7 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -35,6 +35,7 @@ py_library( srcs = ["fisher_factors.py"], srcs_version = "PY2AND3", deps = [ + ":linear_operator", ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -64,6 +65,19 @@ py_library( ) py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + +py_library( name = "loss_functions", srcs = ["loss_functions.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index d11c9c8288..84ebf5e2e2 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs): if placement_strategy in [None, "round_robin"]: return FisherEstimatorRoundRobin(**kwargs) else: - raise ValueError("Unimplemented vars and ops placement strategy : %s", - placement_strategy) + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) # pylint: enable=abstract-class-instantiated @@ -81,7 +81,9 @@ class FisherEstimator(object): exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - name="FisherEstimator"): + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): """Create a FisherEstimator object. Args: @@ -124,6 +126,12 @@ class FisherEstimator(object): name: A string. A name given to this estimator, which is added to the variable scope when constructing variables and ops. (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) Raises: ValueError: If no losses have been registered with layer_collection. """ @@ -142,6 +150,8 @@ class FisherEstimator(object): self._made_vars = False self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse self._name = name @@ -300,9 +310,54 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + assert exp in self._exps + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) return self._apply_transformation(vecs_and_vars, fcn) + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . + + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -333,9 +388,13 @@ class FisherEstimator(object): return self._made_vars def _register_matrix_functions(self): - for exp in self._exps: - for block in self.blocks: + for block in self.blocks: + for exp in self._exps: block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() def _finalize_layer_collection(self): self._layers.create_subgraph() diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py index 33c9696506..9c9fef471f 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'FisherEstimator', + 'make_fisher_estimator', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 00b3673a74..32c776cb38 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): - """Computes the scalar constant pi for Tikhonov regularization/damping. + r"""Computes the scalar constant pi for Tikhonov regularization/damping. $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: - left_cov: The left Kronecker factor "covariance". - right_cov: The right Kronecker factor "covariance". + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ - - def _trace(cov): - if len(cov.shape) == 1: - # Diagonal matrix. - return math_ops.reduce_sum(cov) - elif len(cov.shape) == 2: - # Full matrix. - return math_ops.trace(cov) - else: - raise ValueError( - "What's the trace of a Tensor of rank %d?" % len(cov.shape)) - # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) return math_ops.sqrt(left_norm / right_norm) @@ -188,6 +176,16 @@ class FisherBlock(object): """ pass + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + def register_inverse(self): """Registers a matrix inverse to be computed by the block.""" self.register_matpower(-1) @@ -229,6 +227,33 @@ class FisherBlock(object): return self.multiply_matpower(vector, 1) @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod def tensors_to_compute_grads(self): """Returns the Tensor(s) with respect to which this FisherBlock needs grads. """ @@ -275,15 +300,32 @@ class FullFB(FisherBlock): def register_matpower(self, exp): self._factor.register_matpower(exp, self._damping_func) - def multiply_matpower(self, vector, exp): + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) return utils.column_to_tensors(vector, out_flat) + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov() + return self._factor.get_cov_as_linear_operator().to_dense() def tensors_to_compute_grads(self): return self._params @@ -305,7 +347,47 @@ class FullFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class NaiveDiagonalFB(FisherBlock): +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): """FisherBlock using a diagonal matrix approximation. This type of approximation is generically applicable but quite primitive. @@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock): self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) - return utils.column_to_tensors(vector, out_flat) - - def full_fisher_block(self): - return self._factor.get_cov() - def tensors_to_compute_grads(self): return self._params @@ -452,7 +520,7 @@ class InputOutputMultiTower(object): return self.__outputs -class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): self._damping_func = _package_func(lambda: damping, (damping,)) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. - exp: A scalar representing the power to raise the block before multiplying - it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vec, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - -class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional @@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): self._num_locations) self._damping_func = _package_func(damping_func, damping_id) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vect, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - class KroneckerProductFB(FisherBlock): """A base class for blocks with separate input and output Kronecker factors. @@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock): else: maybe_normalized_damping = damping - return compute_pi_adjusted_damping(self._input_factor.get_cov(), - self._output_factor.get_cov(), - maybe_normalized_damping**0.5) + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) if normalization is not None: damping_id = ("compute_pi_adjusted_damping", @@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock): self._input_factor.register_matpower(exp, self._input_damping_func) self._output_factor.register_matpower(exp, self._output_damping_func) + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) + + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) + @property def _renorm_coeff(self): """Kronecker factor multiplier coefficient. @@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock): """ return 1.0 - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply_matpower( - reshaped_vector, exp, self._output_damping_func) - reshaped_out = self._input_factor.left_multiply_matpower( - reshaped_out, exp, self._input_damping_func) - if self._renorm_coeff != 1.0: - renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) - reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block. @@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock): Returns: The full Fisher block. """ - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() return self._renorm_coeff * utils.kronecker_product(left_factor, right_factor) @@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): - """FisherBlock for convolutional layers using the basic KFC approx. + r"""FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB): self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_matrix(self, matrix, vector): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name + + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 7988a3b92b..30f8a2a4b8 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -24,6 +24,7 @@ import contextlib import numpy as np import six +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops @@ -399,7 +400,7 @@ class FisherFactor(object): the cov update. Returns: - Tensor of same shape as self.get_cov_var(). + Tensor of same shape as self.get_cov(). """ pass @@ -448,78 +449,43 @@ class FisherFactor(object): """Create and return update ops corresponding to registered computations.""" pass - @abc.abstractmethod def get_cov(self): - """Get full covariance matrix. - - Returns: - Tensor of shape [n, n]. Represents all parameter-parameter correlations - captured by this FisherFactor. - """ - pass - - def get_cov_var(self): - """Get variable backing this FisherFactor. - - May or may not be the same as self.get_cov() - - Returns: - Variable of shape self._cov_shape. - """ return self._cov @abc.abstractmethod - def left_multiply_matpower(self, x, exp, damping_func): - """Left multiplies 'x' by matrix power of this factor (w/ damping applied). - - This calculation is essentially: - (C + damping * I)**exp * x - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. - - x can represent either a matrix or a vector. For some factors, 'x' might - represent a vector but actually be stored as a 2D matrix for convenience. - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). + def get_cov_as_linear_operator(self): + pass - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + @abc.abstractmethod + def register_matpower(self, exp, damping_func): pass @abc.abstractmethod - def right_multiply_matpower(self, x, exp, damping_func): - """Right multiplies 'x' by matrix power of this factor (w/ damping applied). + def register_cholesky(self, damping_func): + pass - This calculation is essentially: - x * (C + damping * I)**exp - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. + @abc.abstractmethod + def register_cholesky_inverse(self, damping_func): + pass - Unlike left_multiply_matpower, x will always be a matrix. + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): pass -class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses explicitly. +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. - This class explicitly calculates and stores inverses of covariance matrices - provided by the underlying FisherFactor implementation. It is assumed that - vectors can be represented as 2-D matrices. + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -538,7 +504,19 @@ class InverseProvidingFactor(FisherFactor): self._eigendecomp = None self._damping_funcs_by_id = {} # {hashable: lambda} - super(InverseProvidingFactor, self).__init__() + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } + + super(DenseSquareMatrixFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) @@ -563,8 +541,6 @@ class InverseProvidingFactor(FisherFactor): be the damping value used. i.e. damping = damping_func(). """ if exp == 1.0: - # We don't register these. The user shouldn't even be calling this - # function with exp = 1.0. return damping_id = self._register_damping(damping_func) @@ -572,6 +548,38 @@ class InverseProvidingFactor(FisherFactor): if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" @@ -589,6 +597,32 @@ class InverseProvidingFactor(FisherFactor): assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv + def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] @@ -606,7 +640,8 @@ class InverseProvidingFactor(FisherFactor): # We precompute these so we don't need to evaluate them multiple times (for # each matrix power that uses them) - damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]() + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) for damping_id in self._damping_funcs_by_id} if use_eig: @@ -627,29 +662,91 @@ class InverseProvidingFactor(FisherFactor): self._matpower_by_exp_and_damping.items()): assert exp == -1 damping = damping_value_by_id[damping_id] - ops.append(matpower.assign(utils.posdef_inv(self._cov, damping))) + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. + for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) self._eigendecomp = False return ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests - damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(-1, damping_id)] + return self.get_matpower(-1, damping_func) def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(exp, damping_id)] + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_matpower this doesn't retrieve a stored variable, but instead # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here @@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor): return self._eigendecomp - def get_cov(self): - # Variable contains full covariance matrix. - return self.get_cov_var() - - def left_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - raise ValueError("Left-multiply not yet supported for IndexedSlices.") - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - if exp == 1: - return math_ops.matmul(self.get_cov(), x) + damping_func() * x - - return math_ops.matmul(self.get_matpower(exp, damping_func), x) - - def right_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - if exp == 1: - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping_func() * array_ops.eye(n) - return utils.matmul_sparse_dense(x, damped_cov) - - return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func)) - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - if exp == 1: - return math_ops.matmul(x, self.get_cov()) + damping_func() * x - - return math_ops.matmul(x, self.get_matpower(exp, damping_func)) - - -class FullFactor(InverseProvidingFactor): +class FullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable @@ -757,41 +817,51 @@ class DiagonalFactor(FisherFactor): """ def __init__(self): - self._damping_funcs_by_id = {} # { hashable: lambda } super(DiagonalFactor, self).__init__() + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + @property def _cov_initializer(self): return diagonal_covariance_initializer + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + def make_inverse_update_ops(self): return [] def instantiate_inv_variables(self): pass - def get_cov(self): - # self.get_cov() could be any shape, but it must have one entry per - # parameter. Flatten it into a vector. - cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) - return array_ops.diag(cov_diag_vec) + def register_matpower(self, exp, damping_func): + pass - def left_multiply_matpower(self, x, exp, damping_func): - matpower = (self.get_cov_var() + damping_func())**exp + def register_cholesky(self, damping_func): + pass - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x) + def register_cholesky_inverse(self, damping_func): + pass - if x.shape != matpower.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, matpower)) - return matpower * x + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) - def right_multiply_matpower(self, x, exp, damping_func): - raise NotImplementedError("Only left-multiply is currently supported.") + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) - def register_matpower(self, exp, damping_func): - pass + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) class NaiveDiagonalFactor(DiagonalFactor): @@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor): return self._inputs[tower].device -class FullyConnectedKroneckerFactor(InverseProvidingFactor): +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ @@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): return self._tensors[0][tower].device -class ConvInputKroneckerFactor(InverseProvidingFactor): +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given @@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): return self._inputs[tower].device -class ConvOutputKroneckerFactor(InverseProvidingFactor): +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer @@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): psi_var) in self._option1quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): mu_var) in self._option2quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 0000000000..61cb955ae8 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py index bf12dbaa9a..38a0e287a7 100644 --- a/tensorflow/contrib/kfac/python/ops/placement.py +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device): class RoundRobinPlacementMixin(object): """Implements round robin placement strategy for ops and variables.""" - def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): """Initializes the RoundRobinPlacementMixin class. Args: @@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object): inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. - *args: - **kwargs: + **kwargs: Need something here? """ - super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + super(RoundRobinPlacementMixin, self).__init__(**kwargs) self._cov_devices = cov_devices self._inv_devices = inv_devices diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index b6f42815e7..144295f4c7 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -235,6 +235,13 @@ posdef_eig_functions = { } +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ @@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format): return data_format.endswith("C") -def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. Args: A: tf.IndexedSlices with dense shape [m, n]. B: tf.Tensor with shape [n, k]. name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. + (Default: False) Returns: tf.IndexedSlices resulting from matmul(A, B). @@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name raise ValueError("A must represent a matrix. Found: %s." % A) if B.shape.ndims != 2: raise ValueError("B must be a matrix.") - new_values = math_ops.matmul(A.values, B) + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) return ops.IndexedSlices( new_values, A.indices, |