author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-02-15 16:23:43 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-02-15 16:27:38 -0800
commit     f5b30312013df5b7bd3a50555b2facadd4aed204
tree       4e86630e16e5fc578e4749c1552c1b63272f42a9 /tensorflow/contrib/kfac
parent     e6f69c1161f24e80e71caeab6c721a98b208d5d7
K-FAC: Support for embedding layers, add FisherFactor.{multiply, multiply_inverse}.
PiperOrigin-RevId: 185920837
Diffstat (limited to 'tensorflow/contrib/kfac')
9 files changed, 662 insertions, 124 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 82accd57f0..fb4b3a241c 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import state_ops
@@ -236,10 +237,10 @@ class NaiveDiagonalFBTest(test.TestCase):
     self.assertAllClose(output_flat, explicit)
 
 
-class FullyConnectedDiagonalFB(test.TestCase):
+class FullyConnectedDiagonalFBTest(test.TestCase):
 
   def setUp(self):
-    super(FullyConnectedDiagonalFB, self).setUp()
+    super(FullyConnectedDiagonalFBTest, self).setUp()
 
     self.batch_size = 4
     self.input_size = 6
@@ -375,6 +376,65 @@ class FullyConnectedDiagonalFB(test.TestCase):
     return multiply_result, multiply_inverse_result
 
 
+class EmbeddingKFACFBTest(test.TestCase):
+
+  def testInstantiateFactors(self):
+    with ops.Graph().as_default():
+      random_seed.set_random_seed(200)
+
+      # Create a Fisher Block.
+      vocab_size = 5
+      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+      # Add some examples.
+      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+      outputs = array_ops.constant([[0.], [1.], [2.]])
+      block.register_additional_minibatch(inputs, outputs)
+
+      # Instantiate factor's variables. Ensure it doesn't fail.
+      grads = outputs**2.
+      damping = array_ops.constant(0.)
+      block.instantiate_factors(([grads],), damping)
+
+  def testMultiplyInverse(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      random_seed.set_random_seed(200)
+
+      # Create a Fisher Block.
+      vocab_size = 5
+      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+      # Add some examples.
+      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+      outputs = array_ops.constant([[0.], [1.], [2.]])
+      block.register_additional_minibatch(inputs, outputs)
+
+      # Instantiate factor's variables. Ensure it doesn't fail.
+      grads = outputs**2.
+      damping = array_ops.constant(0.)
+      block.instantiate_factors(([grads],), damping)
+
+      # Create a sparse update.
+      indices = array_ops.constant([1, 3, 4])
+      values = array_ops.constant([[1.], [1.], [1.]])
+      sparse_vector = ops.IndexedSlices(
+          values, indices, dense_shape=[vocab_size, 1])
+      dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
+
+      # Compare Fisher-vector product against explicit result.
+      result = block.multiply_inverse(sparse_vector)
+      expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
+                                                dense_vector)
+
+      sess.run(tf_variables.global_variables_initializer())
+      self.assertAlmostEqual(
+          sess.run(expected_result[1]), sess.run(result.values[0]))
+      self.assertAlmostEqual(
+          sess.run(expected_result[3]), sess.run(result.values[1]))
+      self.assertAlmostEqual(
+          sess.run(expected_result[4]), sess.run(result.values[2]))
+
+
 class FullyConnectedKFACBasicFBTest(test.TestCase):
 
   def testFullyConnectedKFACBasicFBInit(self):
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 753378d9f4..66e18974ab 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -89,6 +89,21 @@ class FisherFactorTestingDummy(ff.FisherFactor):
   def make_inverse_update_ops(self):
     return []
 
+  def get_cov(self):
+    raise NotImplementedError()
+
+  def left_multiply(self, x, damping):
+    raise NotImplementedError()
+
+  def right_multiply(self, x, damping):
+    raise NotImplementedError()
+
+  def left_multiply_inverse(self, x, damping):
+    raise NotImplementedError()
+
+  def right_multiply_inverse(self, x, damping):
+    raise NotImplementedError()
+
 
 class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
   """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
@@ -379,7 +394,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       random_seed.set_random_seed(200)
       tensor = array_ops.ones((2, 3), name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
-      self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
+      self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
 
   def testNaiveDiagonalFactorInitFloat64(self):
     with tf_ops.Graph().as_default():
@@ -387,7 +402,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       random_seed.set_random_seed(200)
       tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
-      cov = factor.get_cov()
+      cov = factor.get_cov_var()
       self.assertEqual(cov.dtype, dtype)
       self.assertEqual([6, 1], cov.get_shape().as_list())
@@ -402,6 +417,29 @@ class NaiveDiagonalFactorTest(test.TestCase):
     self.assertAllClose([[0.75], [1.5]], new_cov)
 
 
+class EmbeddingInputKroneckerFactorTest(test.TestCase):
+
+  def testInitialization(self):
+    with tf_ops.Graph().as_default():
+      input_ids = array_ops.constant([[0], [1], [4]])
+      vocab_size = 5
+      factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+      cov = factor.get_cov_var()
+      self.assertEqual(cov.shape.as_list(), [vocab_size])
+
+  def testCovarianceUpdateOp(self):
+    with tf_ops.Graph().as_default():
+      input_ids = array_ops.constant([[0], [1], [4]])
+      vocab_size = 5
+      factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+      cov_update_op = factor.make_covariance_update_op(0.0)
+
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        new_cov = sess.run(cov_update_op)
+        self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
+
 
 class FullyConnectedKroneckerFactorTest(test.TestCase):
 
   def _testFullyConnectedKroneckerFactorInit(self,
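Note: the expected value in testCovarianceUpdateOp above is just the per-example average of the n-hot indicator vectors for the registered ids. A plain NumPy sketch of that arithmetic (illustrative only, not the TF code under test):

    import numpy as np

    input_ids = np.array([[0], [1], [4]])   # one id per example, vocab_size 5
    vocab_size = 5

    # One-hot encode every id, then average over the batch of 3 examples.
    one_hots = np.zeros((input_ids.size, vocab_size))
    one_hots[np.arange(input_ids.size), input_ids.ravel()] = 1.0
    new_cov = one_hots.sum(axis=0) / input_ids.shape[0]

    assert np.allclose(new_cov, np.array([1., 1., 0., 0., 1.]) / 3.)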
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 0d2fa706f5..cf38d28b43 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -92,10 +92,22 @@ def compute_pi_tracenorm(left_cov, right_cov):
   Returns:
     The computed scalar constant pi for these Kronecker Factors (as a Tensor).
   """
+
+  def _trace(cov):
+    if len(cov.shape) == 1:
+      # Diagonal matrix.
+      return math_ops.reduce_sum(cov)
+    elif len(cov.shape) == 2:
+      # Full matrix.
+      return math_ops.trace(cov)
+    else:
+      raise ValueError(
+          "What's the trace of a Tensor of rank %d?" % len(cov.shape))
+
   # Instead of dividing by the dim of the norm, we multiply by the dim of the
   # other norm. This works out the same in the ratio.
-  left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0]
-  right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0]
+  left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
+  right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
   return math_ops.sqrt(left_norm / right_norm)
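Note: with the new _trace helper, compute_pi_tracenorm accepts a mix of diagonal (rank-1) and full (rank-2) covariances. A NumPy sketch of the quantity it computes, under the same trace-norm definition:

    import numpy as np

    def _trace(cov):
        # Rank-1 covariances store only the diagonal; rank-2 store the full matrix.
        return cov.sum() if cov.ndim == 1 else np.trace(cov)

    left_cov = np.array([0.5, 1.0, 1.5])             # diagonal factor, dim 3
    right_cov = np.array([[2.0, 0.1], [0.1, 1.0]])   # full factor, dim 2

    # Multiply by the other factor's dimension instead of dividing by our own;
    # the ratio works out the same.
    left_norm = _trace(left_cov) * right_cov.shape[0]
    right_norm = _trace(right_cov) * left_cov.shape[0]
    pi = np.sqrt(left_norm / right_norm)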
""" - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply( + reshaped_vec, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -468,12 +486,14 @@ class ConvDiagonalFB(FisherBlock): def multiply_inverse(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply_inverse( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -533,28 +553,24 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_damped_inverse( - self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = math_ops.matmul(left_factor_inv, - math_ops.matmul(reshaped_vector, - right_factor_inv)) + reshaped_out = self._output_factor.right_multiply_inverse( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply_inverse( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out /= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = ( - math_ops.matmul(reshaped_vector, right_factor) + - self._output_damping * reshaped_vector) - reshaped_out = ( - math_ops.matmul(left_factor, reshaped_out) + - self._input_damping * reshaped_out) + reshaped_out = self._output_factor.right_multiply( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out *= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) @@ -574,6 +590,74 @@ class KroneckerProductFB(FisherBlock): right_factor) +class EmbeddingKFACFB(KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to EmbeddingKFACFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._inputs = [] + self._outputs = [] + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. 
@@ -574,6 +590,74 @@ class KroneckerProductFB(FisherBlock):
         right_factor)
 
 
+class EmbeddingKFACFB(KroneckerProductFB):
+  """K-FAC FisherBlock for embedding layers.
+
+  This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
+  input factor is approximated by a diagonal matrix. In the case that each
+  example references exactly one embedding, this approximation is exact.
+
+  Does not support bias parameters.
+  """
+
+  def __init__(self, layer_collection, vocab_size):
+    """Creates an EmbeddingKFACFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+        Fisher information matrix to which this FisherBlock belongs.
+      vocab_size: int. Size of vocabulary for this embedding layer.
+    """
+    self._inputs = []
+    self._outputs = []
+    self._vocab_size = vocab_size
+
+    super(EmbeddingKFACFB, self).__init__(layer_collection)
+
+  def instantiate_factors(self, grads_list, damping):
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of Tensors. grads_list[i][j] is the
+        gradient of the loss with respect to 'outputs' from source 'i' and
+        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    # TODO(b/68033310): Validate which of,
+    #   (1) summing on a single device (as below), or
+    #   (2) on each device in isolation and aggregating
+    # is faster.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.EmbeddingInputKroneckerFactor,  #
+        ((inputs,), self._vocab_size))
+    self._output_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        (grads_list,))
+    self._register_damped_input_and_output_inverses(damping)
+
+  def tensors_to_compute_grads(self):
+    return self._outputs
+
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, input_size]. Inputs to the
+        matrix-multiply.
+      outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
+  @property
+  def num_registered_minibatches(self):
+    return len(self._inputs)
+
+
 class FullyConnectedKFACBasicFB(KroneckerProductFB):
   """K-FAC FisherBlock for fully-connected (dense) layers.
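Note on the grads_list layout EmbeddingKFACFB.instantiate_factors expects: grads_list[i][j] is the gradient from source i and tower j, and towers are concatenated along the batch dimension before the factors are built. A small NumPy stand-in for _concat_along_batch_dim (illustrative shapes):

    import numpy as np

    tower0_grads = np.ones((2, 1))   # [tower_minibatch_size, output_size]
    tower1_grads = np.ones((3, 1))
    grads_list = ((tower0_grads, tower1_grads),)   # one source, two towers

    # Equivalent of _concat_along_batch_dim applied per source.
    concatenated = tuple(np.concatenate(g, axis=0) for g in grads_list)
    assert concatenated[0].shape == (5, 1)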
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index ac39630920..c04cf727fa 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -29,6 +29,7 @@ _allowed_symbols = [
     'NaiveDiagonalFB',
     'FullyConnectedDiagonalFB',
     'KroneckerProductFB',
+    'EmbeddingKFACFB',
     'FullyConnectedKFACBasicFB',
     'ConvKFCBasicFB',
     'ConvDiagonalFB',
@@ -36,7 +37,9 @@ _allowed_symbols = [
     'compute_pi_tracenorm',
     'compute_pi_adjusted_damping',
     'num_conv_locations',
-    'normalize_damping'
+    'normalize_damping',
+    'LEFT_MULTIPLY',
+    'RIGHT_MULTIPLY',
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index bcba18ae14..603d8b8b21 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -25,13 +25,13 @@ import numpy as np
 import six
 
 from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
 from tensorflow.python.ops import special_math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -112,54 +112,6 @@ def diagonal_covariance_initializer(shape, dtype, partition_info):  # pylint: di
   return array_ops.ones(shape, dtype)
 
 
-def extract_image_patches(image, ksizes, strides, padding, name=None):
-  """Extracts image patches for an N-dimensional convolution.
-
-  This function is a compatibility wrapper over tf.extract_image_patches(), as
-  ExtractImagePatches isn't yet implemented in XLA.
-
-  Args:
-    image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images.
-      All dimensions except 'batch' must be defined.
-    ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each
-      dimension.
-    strides: [stride_x, stride_y, ...]. Spatial stride for filter in each
-      dimension.
-    padding: str. "VALID" or "SAME".
-    name: str or None. name of Op.
-
-  Returns:
-    result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels].
-      Contains image patches to which conv kernel would be applied for each
-      output location. [out_x, out_y, ...] depends on padding.
-  """
-  if not utils.on_tpu():
-    return array_ops.extract_image_patches(
-        image,
-        ksizes=([1] + list(ksizes) + [1]),
-        strides=([1] + list(strides) + [1]),
-        rates=[1, 1, 1, 1],
-        padding=padding,
-        name=name)
-
-  with tf_ops.name_scope(name, "extract_image_patches",
-                         [image, ksizes, strides, padding]):
-    batch = image.shape.as_list()[0]
-    in_channels = image.shape.as_list()[-1]
-
-    # Map each input feature to a location in the output.
-    out_channels = np.prod(ksizes) * in_channels
-    filters = linalg_ops.eye(out_channels),
-    filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels])
-
-    result = nn.convolution(image, filters, padding, strides=strides)
-    out_spatial = result.shape.as_list()[1:-1]
-    result = array_ops.reshape(
-        result, [batch or -1] + out_spatial + ksizes + [in_channels])
-
-    return result
-
-
 def compute_cov(tensor, tensor_right=None, normalizer=None):
   """Compute the empirical second moment of the rows of a 2D Tensor.
@@ -259,12 +211,21 @@ def scalar_or_tensor_to_string(val):
 
 class FisherFactor(object):
   """Base class for objects modeling factors of approximate Fisher blocks.
 
-  Note that for blocks that aren't based on approximations, a 'factor' can
-  be the entire block itself, as is the case for the diagonal and full
-  representations.
+  A FisherFactor represents part of an approximate Fisher Information matrix.
+  For example, one approximation to the Fisher uses the Kronecker product of
+  two FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
+  FisherBlocks to construct a block-diagonal approximation to the full Fisher.
+
+  FisherFactors are backed by a single, non-trainable variable that is updated
+  by running FisherFactor.make_covariance_update_op(). The shape and type of
+  this variable is implementation specific.
 
-  Subclasses must implement the _compute_new_cov method, and the _var_scope
-  and _cov_shape properties.
+  Note that for blocks that aren't based on approximations, a 'factor' can
+  be the entire block itself, as is the case for the diagonal and full
+  representations.
+
+  Subclasses must implement the _compute_new_cov() method, and the _var_scope
+  and _cov_shape properties.
   """
 
   def __init__(self):
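Note: the docstring above pins down the FisherFactor contract: one backing variable, refreshed by make_covariance_update_op(). A NumPy sketch of the moving-average update that op performs, assuming the usual exponential-moving-average form used elsewhere in this module:

    import numpy as np

    def ema_update(cov, new_cov, ema_decay):
        # cov <- decay * cov + (1 - decay) * minibatch estimate
        return ema_decay * cov + (1.0 - ema_decay) * new_cov

    cov = np.zeros(5)                                   # backing variable
    minibatch_estimate = np.array([1., 1., 0., 0., 1.]) / 3.
    cov = ema_update(cov, minibatch_estimate, ema_decay=0.9)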
@@ -272,16 +233,21 @@ class FisherFactor(object):
 
   @abc.abstractproperty
   def _var_scope(self):
+    """Variable scope for this FisherFactor instance.
+
+    Returns:
+      string that uniquely identifies this FisherFactor instance.
+    """
     pass
 
   @abc.abstractproperty
   def _cov_shape(self):
-    """The shape of the cov matrix."""
+    """The shape of the variable backing this FisherFactor."""
     pass
 
   @abc.abstractproperty
   def _num_sources(self):
-    """The number of things to sum over when computing cov.
+    """The number of things to sum over when updating the covariance variable.
 
     The default make_covariance_update_op function will call _compute_new_cov
     with indices ranging from 0 to _num_sources-1. The typical situation is
@@ -293,10 +259,12 @@ class FisherFactor(object):
 
   @abc.abstractproperty
   def _dtype(self):
+    """dtype for the variable backing this factor."""
     pass
 
   @property
   def _cov_initializer(self):
+    """Function for initializing the covariance variable."""
     return covariance_initializer
 
   def instantiate_covariance(self):
@@ -311,6 +279,15 @@ class FisherFactor(object):
 
   @abc.abstractmethod
   def _compute_new_cov(self, idx=0):
+    """Computes minibatch-estimated covariance for a single source.
+
+    Args:
+      idx: int in [0, self._num_sources). Which source to use when estimating
+        covariance.
+
+    Returns:
+      Tensor of same shape as self.get_cov_var().
+    """
     pass
 
   def make_covariance_update_op(self, ema_decay):
@@ -343,14 +320,101 @@ class FisherFactor(object):
     """Create and return update ops corresponding to registered computations."""
     pass
 
+  @abc.abstractmethod
   def get_cov(self):
+    """Get full covariance matrix.
+
+    Returns:
+      Tensor of shape [n, n]. Represents all parameter-parameter correlations
+      captured by this FisherFactor.
+    """
+    pass
+
+  def get_cov_var(self):
+    """Get variable backing this FisherFactor.
+
+    May or may not be the same as self.get_cov().
+
+    Returns:
+      Variable of shape self._cov_shape.
+    """
     return self._cov
 
+  @abc.abstractmethod
+  def left_multiply(self, x, damping):
+    """Multiplies 'x' by the damped covariance of this factor.
+
+    Let C be the covariance matrix this factor represents, and
+    D = C + damping * I be its damped variant. This method calculates
+    matmul(D, vec(x)).
+
+    Args:
+      x: Tensor. Represents a single vector. Shape depends on implementation.
+      damping: 0-D Tensor. Damping to add to C's diagonal.
+
+    Returns:
+      Tensor of same shape as 'x'.
+    """
+    pass
+
+  @abc.abstractmethod
+  def right_multiply(self, x, damping):
+    """Multiplies 'x' by the damped covariance of this factor.
+
+    Let C be the covariance matrix this factor represents, and
+    D = C + damping * I be its damped variant. This method calculates
+    matmul(vec(x), D).
+
+    Args:
+      x: Tensor. Represents a single vector. Shape depends on implementation.
+      damping: 0-D Tensor. Damping to add to C's diagonal.
+
+    Returns:
+      Tensor of same shape as 'x'.
+    """
+    pass
+
+  @abc.abstractmethod
+  def left_multiply_inverse(self, x, damping):
+    """Multiplies 'x' by the damped inverse of this factor.
+
+    Let C be the covariance matrix this factor represents and
+    E = inv(C + damping * I) be its damped inverse. This method calculates
+    matmul(E, vec(x)).
+
+    Args:
+      x: Tensor. Represents a single vector. Shape depends on implementation.
+      damping: 0-D Tensor. Damping to add to C's diagonal.
+
+    Returns:
+      Tensor of same shape as 'x'.
+    """
+    pass
+
+  @abc.abstractmethod
+  def right_multiply_inverse(self, x, damping):
+    """Multiplies 'x' by the damped inverse of this factor.
+
+    Let C be the covariance matrix this factor represents and
+    E = inv(C + damping * I) be its damped inverse. This method calculates
+    matmul(vec(x), E).
+
+    Args:
+      x: Tensor. Represents a single vector. Shape depends on implementation.
+      damping: 0-D Tensor. Damping to add to C's diagonal.
+
+    Returns:
+      Tensor of same shape as 'x'.
+    """
+    pass
+
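Note: the four abstract methods above all promise products against the damped covariance D = C + damping * I (or its inverse). A NumPy sketch of the promised semantics, with plain arrays standing in for the TF API:

    import numpy as np

    C = np.array([[2.0, 0.5],
                  [0.5, 1.0]])          # factor covariance
    damping = 0.1
    x = np.array([[1.0], [2.0]])        # a matrix-shaped 'vector'

    D = C + damping * np.eye(2)
    left = D @ x                        # left_multiply(x, damping)
    left_inv = np.linalg.solve(D, x)    # left_multiply_inverse(x, damping)

    # The inverse variant really inverts the damped product.
    assert np.allclose(D @ left_inv, x)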
 class InverseProvidingFactor(FisherFactor):
-  """Base class for FisherFactors that maintain inverses, powers, etc of _cov.
+  """Base class for FisherFactors that maintain inverses explicitly.
 
-  Assumes that the _cov property is a square PSD matrix.
+  This class explicitly calculates and stores inverses of covariance matrices
+  provided by the underlying FisherFactor implementation. It is assumed that
+  vectors can be represented as 2-D matrices.
 
   Subclasses must implement the _compute_new_cov method, and the _var_scope
   and _cov_shape properties.
@@ -485,6 +549,61 @@ class InverseProvidingFactor(FisherFactor):
   def reset_eigendecomp(self):
     self._eigendecomp = None
 
+  def get_cov(self):
+    # Variable contains full covariance matrix.
+    return self.get_cov_var()
+
+  def left_multiply(self, x, damping):
+    n = self.get_cov().shape[0]
+    damped_cov = self.get_cov() + damping * array_ops.eye(n)
+
+    if isinstance(x, tf_ops.IndexedSlices):
+      raise NotImplementedError(
+          "Left-multiply not yet supported for IndexedSlices.")
+
+    if len(x.shape) != 2:
+      raise ValueError(
+          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+          % (x,))
+
+    return math_ops.matmul(damped_cov, x)
+
+  def right_multiply(self, x, damping):
+    n = self.get_cov().shape[0]
+    damped_cov = self.get_cov() + damping * array_ops.eye(n)
+
+    if isinstance(x, tf_ops.IndexedSlices):
+      return utils.matmul_sparse_dense(x, damped_cov)
+
+    if len(x.shape) != 2:
+      raise ValueError(
+          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+          % (x,))
+
+    return math_ops.matmul(x, damped_cov)
+
+  def left_multiply_inverse(self, x, damping):
+    if isinstance(x, tf_ops.IndexedSlices):
+      raise ValueError("Left-multiply not yet supported for IndexedSlices.")
+
+    if x.shape.ndims != 2:
+      raise ValueError(
+          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+          % (x,))
+
+    return math_ops.matmul(self.get_damped_inverse(damping), x)
+
+  def right_multiply_inverse(self, x, damping):
+    if isinstance(x, tf_ops.IndexedSlices):
+      return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping))
+
+    if x.shape.ndims != 2:
+      raise ValueError(
+          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+          % (x,))
+
+    return math_ops.matmul(x, self.get_damped_inverse(damping))
+
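Note: right-multiplication supports IndexedSlices because the sparse operand sits on the left of the matmul: only the stored rows are touched and the sparsity pattern survives. A NumPy sketch of that case (matmul_sparse_dense stand-in; shapes illustrative):

    import numpy as np

    damped_inverse = np.linalg.inv(np.diag([1., 2., 3.]) + 0.1 * np.eye(3))

    # An IndexedSlices-style x with dense shape [5, 3]: one stored row.
    values = np.array([[1.0, 0.0, 2.0]])
    indices = np.array([4])

    # matmul(vec(x), E) touches only the stored rows; rows of x not listed in
    # 'indices' stay implicitly zero, so (new_values, indices) is the result.
    new_values = values @ damped_inverse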
 class FullFactor(InverseProvidingFactor):
   """FisherFactor for a full matrix representation of the Fisher of a parameter.
@@ -530,7 +649,11 @@ class FullFactor(InverseProvidingFactor):
 
 class DiagonalFactor(FisherFactor):
-  """A base class for FisherFactors that use diagonal approximations."""
+  """A base class for FisherFactors that use diagonal approximations.
+
+  A DiagonalFactor's covariance variable can be of any shape, but must contain
+  exactly one entry per parameter.
+  """
 
   def __init__(self):
     super(DiagonalFactor, self).__init__()
@@ -542,6 +665,45 @@ class DiagonalFactor(FisherFactor):
   def make_inverse_update_ops(self):
     return []
 
+  def get_cov(self):
+    # self.get_cov_var() could be any shape, but it must have one entry per
+    # parameter. Flatten it into a vector.
+    cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
+    return array_ops.diag(cov_diag_vec)
+
+  def left_multiply(self, x, damping):
+    damped_cov = self.get_cov_var() + damping
+    if isinstance(x, tf_ops.IndexedSlices):
+      return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x)
+
+    if x.shape != damped_cov.shape:
+      raise ValueError("x (%s) and cov (%s) must have same shape." %
+                       (x, damped_cov))
+
+    return damped_cov * x
+
+  def right_multiply(self, x, damping):
+    raise NotImplementedError("Only left-multiply is currently supported.")
+
+  def left_multiply_inverse(self, x, damping):
+    inverse = 1. / (self.get_cov_var() + damping)
+
+    if isinstance(x, tf_ops.IndexedSlices):
+      return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x)
+
+    if x.shape != inverse.shape:
+      raise ValueError("x (%s) and cov (%s) must have same shape." %
+                       (x, inverse))
+
+    return inverse * x
+
+  def right_multiply_inverse(self, x, damping):
+    raise NotImplementedError("Only left-multiply is currently supported.")
+
+  def register_damped_inverse(self, damping):
+    # DiagonalFactors don't keep explicit inverses.
+    pass
+
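Note: a DiagonalFactor's damped multiplies reduce to elementwise arithmetic on the backing variable, and get_cov() only materializes the dense diagonal on demand. A NumPy sketch checking the equivalence (shapes mirror NaiveDiagonalFactor's [size, 1] variable):

    import numpy as np

    cov_var = np.array([[0.75], [1.5]])        # backing variable, shape [2, 1]
    full_cov = np.diag(cov_var.reshape(-1))    # get_cov(): dense [2, 2] diagonal

    x = np.array([[2.0], [4.0]])
    damping = 0.5

    # left_multiply: elementwise product equals the dense damped matmul.
    assert np.allclose((cov_var + damping) * x,
                       (full_cov + damping * np.eye(2)) @ x)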
 class NaiveDiagonalFactor(DiagonalFactor):
   """FisherFactor for a diagonal approximation of any type of param's Fisher.
@@ -553,6 +715,14 @@
   def __init__(self,
                params_grads,
                batch_size):
+    """Initializes NaiveDiagonalFactor instance.
+
+    Args:
+      params_grads: Sequence of Tensors, each with same shape as parameters this
+        FisherFactor corresponds to. For example, the gradient of the loss with
+        respect to parameters.
+      batch_size: int or 0-D Tensor. Size of the minibatch.
+    """
     self._params_grads = tuple(utils.ensure_sequence(params_grad)
                                for params_grad in params_grads)
     self._batch_size = batch_size
@@ -567,7 +737,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
   def _cov_shape(self):
     size = sum(param_grad.shape.num_elements()
                for param_grad in self._params_grads[0])
-    return (size, 1)
+    return [size, 1]
 
   @property
   def _num_sources(self):
@@ -584,6 +754,84 @@ class NaiveDiagonalFactor(DiagonalFactor):
         self._batch_size, params_grads_flat.dtype))
 
 
+class EmbeddingInputKroneckerFactor(DiagonalFactor):
+  r"""FisherFactor for input to an embedding layer.
+
+  Given input_ids of shape [batch_size, input_size], representing indices into
+  a [vocab_size, embedding_size] embedding matrix, approximate the input
+  covariance by a diagonal matrix,
+
+    Cov(input_ids, input_ids) =
+        (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2).
+
+  where n_hot() constructs an n-hot binary vector and diag() constructs a
+  diagonal matrix of size [vocab_size, vocab_size].
+  """
+
+  def __init__(self, input_ids, vocab_size, dtype=None):
+    """Instantiate EmbeddingInputKroneckerFactor.
+
+    Args:
+      input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype
+        int32. Indices into embedding matrix.
+      vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
+      dtype: dtype for covariance statistics. Must be a floating point type.
+        Defaults to float32.
+    """
+    self._input_ids = input_ids
+    self._vocab_size = vocab_size
+    self._cov_dtype = dtype or dtypes.float32
+
+    super(EmbeddingInputKroneckerFactor, self).__init__()
+
+  @property
+  def _var_scope(self):
+    return "ff_diag_embedding/" + scope_string_from_params(self._input_ids)
+
+  @property
+  def _cov_shape(self):
+    return [self._vocab_size]
+
+  @property
+  def _num_sources(self):
+    return len(self._input_ids)
+
+  @property
+  def _dtype(self):
+    return self._cov_dtype
+
+  def _compute_new_cov(self, idx=0):
+    with maybe_colocate_with(self._input_ids):
+      input_ids = self._input_ids[idx]
+      if len(input_ids.shape) > 2:
+        raise ValueError(
+            "Input to embeddings must have rank <= 2. Found rank %d." % len(
+                input_ids.shape))
+
+      batch_size = array_ops.shape(input_ids)[0]
+
+      # Transform indices into one-hot vectors.
+      #
+      # TODO(b/72714822): There must be a faster way to construct the diagonal
+      # covariance matrix! This operation is O(batch_size * vocab_size), where
+      # it should be O(batch_size * input_size).
+      flat_input_ids = array_ops.reshape(input_ids, [-1])
+      one_hots = array_ops.one_hot(flat_input_ids,
+                                   self._vocab_size)  # [?, vocab_size]
+
+      # Take average across examples. Note that, because all entries have
+      # magnitude zero or one, there's no need to square the entries.
+      #
+      # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
+      # within an example such as average.
+      #
+      # TODO(b/72714822): Support for partitioned embeddings.
+      new_cov = math_ops.reduce_sum(one_hots, axis=0)  # [vocab_size]
+      new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+      return new_cov
+
+
 class FullyConnectedDiagonalFactor(DiagonalFactor):
   r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
 
@@ -623,8 +871,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
 
   @property
   def _cov_shape(self):
-    return [self._inputs.shape[1] + self._has_bias,
-            self._outputs_grads[0].shape[1]]
+    input_size = self._inputs.shape[1] + self._has_bias
+    output_size = self._outputs_grads[0].shape[1]
+    return [input_size, output_size]
 
   @property
   def _num_sources(self):
@@ -717,10 +966,11 @@ class ConvDiagonalFactor(DiagonalFactor):
 
     # TODO(b/64144716): there is potential here for a big savings in terms
     # of memory use.
-    patches = extract_image_patches(
+    patches = array_ops.extract_image_patches(
         self._inputs,
-        ksizes=[filter_height, filter_width],
-        strides=self._strides[1:-1],
+        ksizes=[1, filter_height, filter_width, 1],
+        strides=self._strides,
+        rates=[1, 1, 1, 1],
         padding=self._padding)
 
     if self._has_bias:
@@ -864,10 +1114,11 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
 
     # TODO(b/64144716): there is potential here for a big savings in terms of
     # memory use.
-    patches = extract_image_patches(
+    patches = array_ops.extract_image_patches(
        self._inputs,
-        ksizes=[filter_height, filter_width],
-        strides=self._strides[1:-1],
+        ksizes=[1, filter_height, filter_width, 1],
+        strides=self._strides,
+        rates=[1, 1, 1, 1],
         padding=self._padding)
 
     flatten_size = (filter_height * filter_width * in_channels)
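Note on the O(batch_size * vocab_size) TODO in EmbeddingInputKroneckerFactor._compute_new_cov: for integer ids the same statistic can be formed without materializing one-hots, e.g. by counting ids directly. A NumPy sketch of that alternative (illustrative, not the committed code):

    import numpy as np

    input_ids = np.array([[0], [1], [4]])
    vocab_size = 5

    # Count occurrences of each id, then average over the batch; matches the
    # mean of one-hot rows computed by the committed op.
    counts = np.bincount(input_ids.ravel(), minlength=vocab_size)
    new_cov = counts / input_ids.shape[0]

    assert np.allclose(new_cov, np.array([1., 1., 0., 0., 1.]) / 3.)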
"append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 8d450f04f3..ce9005b9ce 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -143,6 +143,7 @@ class LayerCollection(object): self._loss_dict = {} # {str: LossFunction} self._subgraph = None self._default_generic_approximation = APPROX_FULL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( @@ -179,6 +180,17 @@ class LayerCollection(object): return self._linked_parameters @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + + @property def default_generic_approximation(self): return self._default_generic_approximation @@ -417,6 +429,46 @@ class LayerCollection(object): else: return None + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, output_size]. Outputs + produced by layer. + approx: str. Must be "kron". + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_embedding_approximation + + if approx != APPROX_KRONECKER_NAME: + raise ValueError("Bad value {} for approx.".format(approx)) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + + vocab_size = int(params.shape[0]) + block = self.register_block( + params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + def register_fully_connected(self, params, inputs, diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index e89508fa46..f5bd97cb4e 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -144,7 +144,9 @@ def layer_params_to_mat2d(vector): [-1, w_part.shape.as_list()[-1]]) return array_ops.concat( (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) - else: + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. 
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index e89508fa46..f5bd97cb4e 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -144,7 +144,9 @@ def layer_params_to_mat2d(vector):
         [-1, w_part.shape.as_list()[-1]])
     return array_ops.concat(
         (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
-  else:
+  elif isinstance(vector, ops.IndexedSlices):
+    return vector
+  else:  # Tensor or Tensor-like.
     return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
@@ -163,6 +165,11 @@ def mat2d_to_layer_params(vector_template, mat2d):
   if isinstance(vector_template, (tuple, list)):
     w_part, b_part = mat2d[:-1], mat2d[-1]
     return array_ops.reshape(w_part, vector_template[0].shape), b_part
+  elif isinstance(vector_template, ops.IndexedSlices):
+    if not isinstance(mat2d, ops.IndexedSlices):
+      raise TypeError(
+          "If vector_template is an IndexedSlices, so should mat2d be.")
+    return mat2d
   else:
     return array_ops.reshape(mat2d, vector_template.shape)
@@ -420,5 +427,57 @@ def batch_execute(global_step, thunks, batch_size, name=None):
   return result
 
 
+def matmul_sparse_dense(A, B, name=None):  # pylint: disable=invalid-name
+  """Computes matmul(A, B) where A is sparse, B is dense.
+
+  Args:
+    A: tf.IndexedSlices with dense shape [m, n].
+    B: tf.Tensor with shape [n, k].
+    name: str. Name of op.
+
+  Returns:
+    tf.IndexedSlices resulting from matmul(A, B).
+
+  Raises:
+    ValueError: If A doesn't represent a matrix.
+    ValueError: If B is not rank-2.
+  """
+  with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
+    if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
+      raise ValueError("A must represent a matrix. Found: %s." % A)
+    if B.shape.ndims != 2:
+      raise ValueError("B must be a matrix.")
+    new_values = math_ops.matmul(A.values, B)
+    return ops.IndexedSlices(
+        new_values,
+        A.indices,
+        dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
+
+
+def matmul_diag_sparse(A_diag, B, name=None):  # pylint: disable=invalid-name
+  """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
+
+  Args:
+    A_diag: diagonal entries of matrix A of shape [m, m].
+    B: tf.IndexedSlices. Represents matrix of shape [m, n].
+    name: str. Name of op.
+
+  Returns:
+    tf.IndexedSlices resulting from matmul(A, B).
+
+  Raises:
+    ValueError: If A_diag is not rank-1.
+    ValueError: If B doesn't represent a matrix.
+  """
+  with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
+    A_diag = ops.convert_to_tensor(A_diag)
+    if A_diag.shape.ndims != 1:
+      raise ValueError("A_diag must be a rank-1 Tensor.")
+    if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
+      raise ValueError("B must represent a matrix. Found: %s." % B)
+    a = array_ops.gather(A_diag, B.indices)
+    a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
+    return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
+
+
 # TODO(b/69623235): Add a function for finding tensors that share gradients
 # to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index fe8e39c212..8e424a7946 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -40,6 +40,8 @@ _allowed_symbols = [
     "fwd_gradients",
     "ensure_sequence",
     "batch_execute",
+    "matmul_sparse_dense",
+    "matmul_diag_sparse",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
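Note: the two new utils implement sparse-aware matmuls without densifying. A NumPy sketch of matmul_diag_sparse's gather-and-scale trick (illustrative stand-in, not the TF op):

    import numpy as np

    A_diag = np.array([2., 3., 5., 7., 11.])     # diagonal of A, shape [m]
    B_values = np.array([[1., 2.], [3., 4.]])    # stored rows of sparse B
    B_indices = np.array([1, 4])                 # row indices into [m, n]

    # Gather the diagonal entries for the stored rows and scale them.
    out_values = A_diag[B_indices][:, None] * B_values

    # Same result as the dense diag-matmul, restricted to the stored rows.
    dense_B = np.zeros((5, 2)); dense_B[B_indices] = B_values
    dense_out = np.zeros((5, 2)); dense_out[B_indices] = out_values
    assert np.allclose(np.diag(A_diag) @ dense_B, dense_out)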