author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-15 16:23:43 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-02-15 16:27:38 -0800
commit    f5b30312013df5b7bd3a50555b2facadd4aed204 (patch)
tree      4e86630e16e5fc578e4749c1552c1b63272f42a9 /tensorflow/contrib/kfac
parent    e6f69c1161f24e80e71caeab6c721a98b208d5d7 (diff)
K-FAC: Support for embedding layers, add FisherFactor.{multiply, multiply_inverse}.
PiperOrigin-RevId: 185920837
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py   |  64
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py  |  42
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py                 | 144
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py             |   5
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py                | 387
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py            |  29
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py              |  52
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py                         |  61
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils_lib.py                     |   2
9 files changed, 662 insertions(+), 124 deletions(-)
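
For orientation, here is a minimal usage sketch (editorial, not part of the diff) of the embedding support this commit introduces. It assumes the register_embedding API added to layer_collection.py below and the contrib.kfac import paths of this era.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

vocab_size, embedding_size = 5, 3
input_ids = tf.constant([0, 1, 4])                       # [batch_size] int32 ids
embedding = tf.get_variable("embedding", [vocab_size, embedding_size])
outputs = tf.nn.embedding_lookup(embedding, input_ids)   # [batch_size, embedding_size]

layer_collection = lc.LayerCollection()
# New in this commit: embeddings get a Kronecker-factored FisherBlock whose
# input factor is the diagonal EmbeddingInputKroneckerFactor.
layer_collection.register_embedding(embedding, input_ids, outputs)
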
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 82accd57f0..fb4b3a241c 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
@@ -236,10 +237,10 @@ class NaiveDiagonalFBTest(test.TestCase):
self.assertAllClose(output_flat, explicit)
-class FullyConnectedDiagonalFB(test.TestCase):
+class FullyConnectedDiagonalFBTest(test.TestCase):
def setUp(self):
- super(FullyConnectedDiagonalFB, self).setUp()
+ super(FullyConnectedDiagonalFBTest, self).setUp()
self.batch_size = 4
self.input_size = 6
@@ -375,6 +376,65 @@ class FullyConnectedDiagonalFB(test.TestCase):
return multiply_result, multiply_inverse_result
+class EmbeddingKFACFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_minibatch(inputs, outputs)
+
+ # Instantiate factor's variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(([grads],), damping)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_minibatch(inputs, outputs)
+
+ # Instantiate factor's variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(([grads],), damping)
+
+ # Create a sparse update.
+ indices = array_ops.constant([1, 3, 4])
+ values = array_ops.constant([[1.], [1.], [1.]])
+ sparse_vector = ops.IndexedSlices(
+ values, indices, dense_shape=[vocab_size, 1])
+ dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
+
+      # Compare inverse Fisher-vector product against explicit result.
+ result = block.multiply_inverse(sparse_vector)
+ expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
+ dense_vector)
+
+ sess.run(tf_variables.global_variables_initializer())
+ self.assertAlmostEqual(
+ sess.run(expected_result[1]), sess.run(result.values[0]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[3]), sess.run(result.values[1]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[4]), sess.run(result.values[2]))
+
+
class FullyConnectedKFACBasicFBTest(test.TestCase):
def testFullyConnectedKFACBasicFBInit(self):
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 753378d9f4..66e18974ab 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -89,6 +89,21 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def make_inverse_update_ops(self):
return []
+ def get_cov(self):
+ return NotImplementedError
+
+ def left_multiply(self, x, damping):
+ return NotImplementedError
+
+ def right_multiply(self, x, damping):
+ return NotImplementedError
+
+ def left_multiply_inverse(self, x, damping):
+ return NotImplementedError
+
+ def right_multiply_inverse(self, x, damping):
+ return NotImplementedError
+
class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
"""Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
@@ -379,7 +394,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
- self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
+ self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
with tf_ops.Graph().as_default():
@@ -387,7 +402,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
- cov = factor.get_cov()
+ cov = factor.get_cov_var()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
@@ -402,6 +417,29 @@ class NaiveDiagonalFactorTest(test.TestCase):
self.assertAllClose([[0.75], [1.5]], new_cov)
+class EmbeddingInputKroneckerFactorTest(test.TestCase):
+
+ def testInitialization(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ cov = factor.get_cov_var()
+ self.assertEqual(cov.shape.as_list(), [vocab_size])
+
+ def testCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(cov_update_op)
+ self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
+
+
class FullyConnectedKroneckerFactorTest(test.TestCase):
def _testFullyConnectedKroneckerFactorInit(self,
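
As a side note on the value asserted in testCovarianceUpdateOp above, a quick NumPy check (editorial sketch, not part of the commit) shows why the covariance comes out as [1, 1, 0, 0, 1] / 3 for input_ids [[0], [1], [4]] and vocab_size 5.

import numpy as np

input_ids = np.array([[0], [1], [4]])
vocab_size = 5
batch_size = input_ids.shape[0]

one_hots = np.eye(vocab_size)[input_ids.reshape(-1)]  # one row per id: e_0, e_1, e_4
new_cov = one_hots.sum(axis=0) / batch_size           # [1, 1, 0, 0, 1] / 3
print(new_cov)
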
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 0d2fa706f5..cf38d28b43 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -92,10 +92,22 @@ def compute_pi_tracenorm(left_cov, right_cov):
Returns:
The computed scalar constant pi for these Kronecker Factors (as a Tensor).
"""
+
+ def _trace(cov):
+ if len(cov.shape) == 1:
+ # Diagonal matrix.
+ return math_ops.reduce_sum(cov)
+ elif len(cov.shape) == 2:
+ # Full matrix.
+ return math_ops.trace(cov)
+ else:
+ raise ValueError(
+ "What's the trace of a Tensor of rank %d?" % len(cov.shape))
+
# Instead of dividing by the dim of the norm, we multiply by the dim of the
# other norm. This works out the same in the ratio.
- left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0]
- right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0]
+ left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
+ right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
return math_ops.sqrt(left_norm / right_norm)
@@ -201,15 +213,15 @@ class FullFB(FisherBlock):
self._factor.register_damped_inverse(damping)
def multiply_inverse(self, vector):
- inverse = self._factor.get_damped_inverse(self._damping)
- out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector))
+ vector_flat = utils.tensors_to_column(vector)
+ out_flat = self._factor.left_multiply_inverse(
+ vector_flat, self._damping)
return utils.column_to_tensors(vector, out_flat)
def multiply(self, vector):
vector_flat = utils.tensors_to_column(vector)
- out_flat = (
- math_ops.matmul(self._factor.get_cov(), vector_flat) +
- self._damping * vector_flat)
+ out_flat = self._factor.left_multiply(
+ vector_flat, self._damping)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
@@ -265,16 +277,20 @@ class NaiveDiagonalFB(FisherBlock):
def multiply_inverse(self, vector):
vector_flat = utils.tensors_to_column(vector)
- out_flat = vector_flat / (self._factor.get_cov() + self._damping)
+ print("vector_flat: %s" % vector_flat)
+ out_flat = self._factor.left_multiply_inverse(
+ vector_flat, self._damping)
+ print("out_flat: %s" % out_flat)
return utils.column_to_tensors(vector, out_flat)
def multiply(self, vector):
vector_flat = utils.tensors_to_column(vector)
- out_flat = vector_flat * (self._factor.get_cov() + self._damping)
+ out_flat = self._factor.left_multiply(
+ vector_flat, self._damping)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
- return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,)))
+ return self._factor.get_cov()
def tensors_to_compute_grads(self):
return self._params
@@ -356,8 +372,9 @@ class FullyConnectedDiagonalFB(FisherBlock):
Tensor of the same shape, corresponding to the inverse Fisher-vector
product.
"""
- reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping)
+ reshaped_vec = utils.layer_params_to_mat2d(vector)
+ reshaped_out = self._factor.left_multiply_inverse(
+ reshaped_vec, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
@@ -372,8 +389,9 @@ class FullyConnectedDiagonalFB(FisherBlock):
Returns:
Tensor of the same shape, corresponding to the Fisher-vector product.
"""
- reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping)
+ reshaped_vec = utils.layer_params_to_mat2d(vector)
+ reshaped_out = self._factor.left_multiply(
+ reshaped_vec, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def tensors_to_compute_grads(self):
@@ -468,12 +486,14 @@ class ConvDiagonalFB(FisherBlock):
def multiply_inverse(self, vector):
reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping)
+ reshaped_out = self._factor.left_multiply_inverse(
+ reshaped_vect, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
reshaped_vect = utils.layer_params_to_mat2d(vector)
- reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping)
+ reshaped_out = self._factor.left_multiply(
+ reshaped_vect, self._damping)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def tensors_to_compute_grads(self):
@@ -533,28 +553,24 @@ class KroneckerProductFB(FisherBlock):
return 1.0
def multiply_inverse(self, vector):
- left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping)
- right_factor_inv = self._output_factor.get_damped_inverse(
- self._output_damping)
reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = math_ops.matmul(left_factor_inv,
- math_ops.matmul(reshaped_vector,
- right_factor_inv))
+ reshaped_out = self._output_factor.right_multiply_inverse(
+ reshaped_vector,
+ self._output_damping)
+ reshaped_out = self._input_factor.left_multiply_inverse(
+ reshaped_out, self._input_damping)
if self._renorm_coeff != 1.0:
reshaped_out /= math_ops.cast(
self._renorm_coeff, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply(self, vector):
- left_factor = self._input_factor.get_cov()
- right_factor = self._output_factor.get_cov()
reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = (
- math_ops.matmul(reshaped_vector, right_factor) +
- self._output_damping * reshaped_vector)
- reshaped_out = (
- math_ops.matmul(left_factor, reshaped_out) +
- self._input_damping * reshaped_out)
+ reshaped_out = self._output_factor.right_multiply(
+ reshaped_vector,
+ self._output_damping)
+ reshaped_out = self._input_factor.left_multiply(
+ reshaped_out, self._input_damping)
if self._renorm_coeff != 1.0:
reshaped_out *= math_ops.cast(
self._renorm_coeff, dtype=reshaped_out.dtype)
@@ -574,6 +590,74 @@ class KroneckerProductFB(FisherBlock):
right_factor)
+class EmbeddingKFACFB(KroneckerProductFB):
+ """K-FAC FisherBlock for embedding layers.
+
+  This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
+ input factor is approximated by a diagonal matrix. In the case that each
+ example references exactly one embedding, this approximation is exact.
+
+ Does not support bias parameters.
+ """
+
+ def __init__(self, layer_collection, vocab_size):
+    """Creates an EmbeddingKFACFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ vocab_size: int. Size of vocabulary for this embedding layer.
+ """
+ self._inputs = []
+ self._outputs = []
+ self._vocab_size = vocab_size
+
+ super(EmbeddingKFACFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of Tensors. grads_list[i][j] is the
+ gradient of the loss with respect to 'outputs' from source 'i' and
+ tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ # TODO(b/68033310): Validate which of,
+ # (1) summing on a single device (as below), or
+ # (2) on each device in isolation and aggregating
+ # is faster.
+ inputs = _concat_along_batch_dim(self._inputs)
+ grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor( #
+ fisher_factors.EmbeddingInputKroneckerFactor, #
+ ((inputs,), self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor( #
+ fisher_factors.FullyConnectedKroneckerFactor, #
+ (grads_list,))
+ self._register_damped_input_and_output_inverses(damping)
+
+ def tensors_to_compute_grads(self):
+ return self._outputs
+
+ def register_additional_minibatch(self, inputs, outputs):
+ """Registers an additional minibatch to the FisherBlock.
+
+ Args:
+      inputs: Tensor of shape [batch_size, input_size] and dtype int32.
+        Indices into the embedding matrix.
+ outputs: Tensor of shape [batch_size, output_size]. Layer preactivations.
+ """
+ self._inputs.append(inputs)
+ self._outputs.append(outputs)
+
+ @property
+ def num_registered_minibatches(self):
+ return len(self._inputs)
+
+
class FullyConnectedKFACBasicFB(KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index ac39630920..c04cf727fa 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -29,6 +29,7 @@ _allowed_symbols = [
'NaiveDiagonalFB',
'FullyConnectedDiagonalFB',
'KroneckerProductFB',
+ 'EmbeddingKFACFB',
'FullyConnectedKFACBasicFB',
'ConvKFCBasicFB',
'ConvDiagonalFB',
@@ -36,7 +37,9 @@ _allowed_symbols = [
'compute_pi_tracenorm',
'compute_pi_adjusted_damping',
'num_conv_locations',
- 'normalize_damping'
+ 'normalize_damping',
+ 'LEFT_MULTIPLY',
+ 'RIGHT_MULTIPLY',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
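
Stepping back from the fisher_blocks changes above: the refactored KroneckerProductFB.multiply_inverse relies on the usual Kronecker identity that, for F = kron(A, B) with symmetric factors and a row-major vec of the [input_size, output_size] parameter matrix X, F^{-1} vec(X) = vec(A^{-1} X B^{-1}). A small NumPy check of that identity (editorial sketch, not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
in_dim, out_dim = 3, 2
A = np.cov(rng.randn(in_dim, 10)) + np.eye(in_dim)    # input factor (SPD)
B = np.cov(rng.randn(out_dim, 10)) + np.eye(out_dim)  # output factor (SPD)
X = rng.randn(in_dim, out_dim)                        # parameter matrix

lhs = np.linalg.solve(np.kron(A, B), X.reshape(-1))   # F^{-1} vec(X)
rhs = np.linalg.inv(A).dot(X).dot(np.linalg.inv(B)).reshape(-1)
assert np.allclose(lhs, rhs)
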
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index bcba18ae14..603d8b8b21 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -25,13 +25,13 @@ import numpy as np
import six
from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -112,54 +112,6 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di
return array_ops.ones(shape, dtype)
-def extract_image_patches(image, ksizes, strides, padding, name=None):
- """Extracts image patches for an N-dimensional convolution.
-
- This function is a compatibility wrapper over tf.extract_image_patches(), as
- ExtractImagePatches isn't yet implemented in XLA.
-
- Args:
- image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images.
- All dimensions except 'batch' must be defined.
- ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each
- dimension.
- strides: [stride_x, stride_y, ...]. Spatial stride for filter in each
- dimension.
- padding: str. "VALID" or "SAME".
- name: str or None. name of Op.
-
- Returns:
- result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels].
- Contains image patches to which conv kernel would be applied for each
- output location. [out_x, out_y, ...] depends on padding.
- """
- if not utils.on_tpu():
- return array_ops.extract_image_patches(
- image,
- ksizes=([1] + list(ksizes) + [1]),
- strides=([1] + list(strides) + [1]),
- rates=[1, 1, 1, 1],
- padding=padding,
- name=name)
-
- with tf_ops.name_scope(name, "extract_image_patches",
- [image, ksizes, strides, padding]):
- batch = image.shape.as_list()[0]
- in_channels = image.shape.as_list()[-1]
-
- # Map each input feature to a location in the output.
- out_channels = np.prod(ksizes) * in_channels
- filters = linalg_ops.eye(out_channels),
- filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels])
-
- result = nn.convolution(image, filters, padding, strides=strides)
- out_spatial = result.shape.as_list()[1:-1]
- result = array_ops.reshape(
- result, [batch or -1] + out_spatial + ksizes + [in_channels])
-
- return result
-
-
def compute_cov(tensor, tensor_right=None, normalizer=None):
"""Compute the empirical second moment of the rows of a 2D Tensor.
@@ -259,12 +211,21 @@ def scalar_or_tensor_to_string(val):
class FisherFactor(object):
"""Base class for objects modeling factors of approximate Fisher blocks.
- Note that for blocks that aren't based on approximations, a 'factor' can
- be the entire block itself, as is the case for the diagonal and full
- representations.
+ A FisherFactor represents part of an approximate Fisher Information matrix.
+ For example, one approximation to the Fisher uses the Kronecker product of two
+ FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
+ FisherBlocks to construct a block-diagonal approximation to the full Fisher.
+
+ FisherFactors are backed by a single, non-trainable variable that is updated
+ by running FisherFactor.make_covariance_update_op(). The shape and type of
+ this variable is implementation specific.
- Subclasses must implement the _compute_new_cov method, and the _var_scope
- and _cov_shape properties.
+ Note that for blocks that aren't based on approximations, a 'factor' can
+ be the entire block itself, as is the case for the diagonal and full
+ representations.
+
+ Subclasses must implement the _compute_new_cov() method, and the _var_scope
+ and _cov_shape properties.
"""
def __init__(self):
@@ -272,16 +233,21 @@ class FisherFactor(object):
@abc.abstractproperty
def _var_scope(self):
+ """Variable scope for this FisherFactor instance.
+
+ Returns:
+      String that uniquely identifies this FisherFactor instance.
+ """
pass
@abc.abstractproperty
def _cov_shape(self):
- """The shape of the cov matrix."""
+ """The shape of the variable backing this FisherFactor."""
pass
@abc.abstractproperty
def _num_sources(self):
- """The number of things to sum over when computing cov.
+ """The number of things to sum over when updating covariance variable.
The default make_covariance_update_op function will call _compute_new_cov
with indices ranging from 0 to _num_sources-1. The typical situation is
@@ -293,10 +259,12 @@ class FisherFactor(object):
@abc.abstractproperty
def _dtype(self):
+ """dtype for variable backing this factor."""
pass
@property
def _cov_initializer(self):
+ """Function for initializing covariance variable."""
return covariance_initializer
def instantiate_covariance(self):
@@ -311,6 +279,15 @@ class FisherFactor(object):
@abc.abstractmethod
def _compute_new_cov(self, idx=0):
+ """Computes minibatch-estimated covariance for a single source.
+
+ Args:
+ idx: int in [0, self._num_sources). Which source to use when estimating
+ covariance.
+
+ Returns:
+ Tensor of same shape as self.get_cov_var().
+ """
pass
def make_covariance_update_op(self, ema_decay):
@@ -343,14 +320,101 @@ class FisherFactor(object):
"""Create and return update ops corresponding to registered computations."""
pass
+ @abc.abstractmethod
def get_cov(self):
+ """Get full covariance matrix.
+
+ Returns:
+ Tensor of shape [n, n]. Represents all parameter-parameter correlations
+ captured by this FisherFactor.
+ """
+ pass
+
+ def get_cov_var(self):
+ """Get variable backing this FisherFactor.
+
+    May or may not be the same as self.get_cov().
+
+ Returns:
+ Variable of shape self._cov_shape.
+ """
return self._cov
+ @abc.abstractmethod
+ def left_multiply(self, x, damping):
+ """Multiplies 'x' by the damped covariance of this factor.
+
+ Let C be the covariance matrix this factor represents, and
+ D = C + damping * I be its damped variant. This method calculates
+ matmul(D, vec(x)).
+
+ Args:
+ x: Tensor. Represents a single vector. Shape depends on implementation.
+ damping: 0-D Tensor. Damping to add to C's diagonal.
+
+ Returns:
+ Tensor of same shape as 'x'.
+ """
+ pass
+
+ @abc.abstractmethod
+ def right_multiply(self, x, damping):
+ """Multiplies 'x' by the damped covariance of this factor.
+
+ Let C be the covariance matrix this factor represents, and
+ D = C + damping * I be its damped variant. This method calculates
+ matmul(vec(x), D).
+
+ Args:
+ x: Tensor. Represents a single vector. Shape depends on implementation.
+ damping: 0-D Tensor. Damping to add to C's diagonal.
+
+ Returns:
+ Tensor of same shape as 'x'.
+ """
+ pass
+
+ @abc.abstractmethod
+ def left_multiply_inverse(self, x, damping):
+ """Multiplies 'x' by damped inverse of this factor.
+
+ Let C be the covariance matrix this factor represents and
+ E = inv(C + damping * I) be its damped inverse. This method calculates
+ matmul(E, vec(x)).
+
+ Args:
+ x: Tensor. Represents a single vector. Shape depends on implementation.
+ damping: 0-D Tensor. Damping to add to C's diagonal.
+
+ Returns:
+ Tensor of same shape as 'x'.
+ """
+ pass
+
+ @abc.abstractmethod
+ def right_multiply_inverse(self, x, damping):
+ """Multiplies 'x' by damped inverse of this factor.
+
+ Let C be the covariance matrix this factor represents and
+ E = inv(C + damping * I) be its damped inverse. This method calculates
+ matmul(vec(x), E).
+
+ Args:
+ x: Tensor. Represents a single vector. Shape depends on implementation.
+ damping: 0-D Tensor. Damping to add to C's diagonal.
+
+ Returns:
+ Tensor of same shape as 'x'.
+ """
+ pass
+
class InverseProvidingFactor(FisherFactor):
- """Base class for FisherFactors that maintain inverses, powers, etc of _cov.
+ """Base class for FisherFactors that maintain inverses explicitly.
- Assumes that the _cov property is a square PSD matrix.
+ This class explicitly calculates and stores inverses of covariance matrices
+ provided by the underlying FisherFactor implementation. It is assumed that
+ vectors can be represented as 2-D matrices.
Subclasses must implement the _compute_new_cov method, and the _var_scope and
_cov_shape properties.
@@ -485,6 +549,61 @@ class InverseProvidingFactor(FisherFactor):
def reset_eigendecomp(self):
self._eigendecomp = None
+ def get_cov(self):
+ # Variable contains full covariance matrix.
+ return self.get_cov_var()
+
+ def left_multiply(self, x, damping):
+ n = self.get_cov().shape[0]
+ damped_cov = self.get_cov() + damping * array_ops.eye(n)
+
+ if isinstance(x, tf_ops.IndexedSlices):
+ raise NotImplementedError(
+ "Left-multiply not yet supported for IndexedSlices.")
+
+ if len(x.shape) != 2:
+ raise ValueError(
+ "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+ % (x,))
+
+ return math_ops.matmul(damped_cov, x)
+
+ def right_multiply(self, x, damping):
+ n = self.get_cov().shape[0]
+ damped_cov = self.get_cov() + damping * array_ops.eye(n)
+
+ if isinstance(x, tf_ops.IndexedSlices):
+ return utils.matmul_sparse_dense(x, damped_cov)
+
+ if len(x.shape) != 2:
+ raise ValueError(
+ "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+ % (x,))
+
+ return math_ops.matmul(x, damped_cov)
+
+ def left_multiply_inverse(self, x, damping):
+ if isinstance(x, tf_ops.IndexedSlices):
+ raise ValueError("Left-multiply not yet supported for IndexedSlices.")
+
+ if x.shape.ndims != 2:
+ raise ValueError(
+ "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+ % (x,))
+
+ return math_ops.matmul(self.get_damped_inverse(damping), x)
+
+ def right_multiply_inverse(self, x, damping):
+ if isinstance(x, tf_ops.IndexedSlices):
+ return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping))
+
+ if x.shape.ndims != 2:
+ raise ValueError(
+ "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
+ % (x,))
+
+ return math_ops.matmul(x, self.get_damped_inverse(damping))
+
class FullFactor(InverseProvidingFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
@@ -530,7 +649,11 @@ class FullFactor(InverseProvidingFactor):
class DiagonalFactor(FisherFactor):
- """A base class for FisherFactors that use diagonal approximations."""
+ """A base class for FisherFactors that use diagonal approximations.
+
+ A DiagonalFactor's covariance variable can be of any shape, but must contain
+ exactly one entry per parameter.
+ """
def __init__(self):
super(DiagonalFactor, self).__init__()
@@ -542,6 +665,45 @@ class DiagonalFactor(FisherFactor):
def make_inverse_update_ops(self):
return []
+ def get_cov(self):
+    # self.get_cov_var() could be any shape, but it must have one entry per
+ # parameter. Flatten it into a vector.
+ cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
+ return array_ops.diag(cov_diag_vec)
+
+ def left_multiply(self, x, damping):
+ damped_cov = self.get_cov_var() + damping
+ if isinstance(x, tf_ops.IndexedSlices):
+ return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x)
+
+ if x.shape != damped_cov.shape:
+ raise ValueError("x (%s) and cov (%s) must have same shape." %
+ (x, damped_cov))
+
+ return damped_cov * x
+
+ def right_multiply(self, x, damping):
+ raise NotImplementedError("Only left-multiply is currently supported.")
+
+ def left_multiply_inverse(self, x, damping):
+ inverse = 1. / (self.get_cov_var() + damping)
+
+ if isinstance(x, tf_ops.IndexedSlices):
+ return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x)
+
+ if x.shape != inverse.shape:
+ raise ValueError("x (%s) and cov (%s) must have same shape." %
+ (x, inverse))
+
+ return inverse * x
+
+ def right_multiply_inverse(self, x, damping):
+ raise NotImplementedError("Only left-multiply is currently supported.")
+
+ def register_damped_inverse(self, damping):
+ # DiagonalFactors don't keep explicit inverses.
+ pass
+
class NaiveDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approximation of any type of param's Fisher.
@@ -553,6 +715,14 @@ class NaiveDiagonalFactor(DiagonalFactor):
def __init__(self,
params_grads,
batch_size):
+ """Initializes NaiveDiagonalFactor instance.
+
+ Args:
+ params_grads: Sequence of Tensors, each with same shape as parameters this
+ FisherFactor corresponds to. For example, the gradient of the loss with
+ respect to parameters.
+      batch_size: int or 0-D Tensor. Size of the minibatch.
+ """
self._params_grads = tuple(utils.ensure_sequence(params_grad)
for params_grad in params_grads)
self._batch_size = batch_size
@@ -567,7 +737,7 @@ class NaiveDiagonalFactor(DiagonalFactor):
def _cov_shape(self):
size = sum(param_grad.shape.num_elements()
for param_grad in self._params_grads[0])
- return (size, 1)
+ return [size, 1]
@property
def _num_sources(self):
@@ -584,6 +754,84 @@ class NaiveDiagonalFactor(DiagonalFactor):
self._batch_size, params_grads_flat.dtype))
+class EmbeddingInputKroneckerFactor(DiagonalFactor):
+ r"""FisherFactor for input to an embedding layer.
+
+  Given input_ids = [batch_size, input_size] representing indices into a
+ [vocab_size, embedding_size] embedding matrix, approximate input covariance by
+ a diagonal matrix,
+
+ Cov(input_ids, input_ids) =
+      (1/batch_size) sum_{i} diag(n_hot(input_ids[i]) ** 2),
+
+ where n_hot() constructs an n-hot binary vector and diag() constructs a
+ diagonal matrix of size [vocab_size, vocab_size].
+ """
+
+ def __init__(self, input_ids, vocab_size, dtype=None):
+ """Instantiate EmbeddingInputKroneckerFactor.
+
+ Args:
+ input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype
+ int32. Indices into embedding matrix.
+      vocab_size: int or 0-D Tensor. Upper bound (exclusive) on 'input_ids'.
+ dtype: dtype for covariance statistics. Must be a floating point type.
+ Defaults to float32.
+ """
+ self._input_ids = input_ids
+ self._vocab_size = vocab_size
+ self._cov_dtype = dtype or dtypes.float32
+
+ super(EmbeddingInputKroneckerFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_diag_embedding/" + scope_string_from_params(self._input_ids)
+
+ @property
+ def _cov_shape(self):
+ return [self._vocab_size]
+
+ @property
+ def _num_sources(self):
+ return len(self._input_ids)
+
+ @property
+ def _dtype(self):
+ return self._cov_dtype
+
+ def _compute_new_cov(self, idx=0):
+ with maybe_colocate_with(self._input_ids):
+ input_ids = self._input_ids[idx]
+ if len(input_ids.shape) > 2:
+ raise ValueError(
+ "Input to embeddings must have rank <= 2. Found rank %d." % len(
+ input_ids.shape))
+
+ batch_size = array_ops.shape(input_ids)[0]
+
+ # Transform indices into one-hot vectors.
+ #
+ # TODO(b/72714822): There must be a faster way to construct the diagonal
+ # covariance matrix! This operation is O(batch_size * vocab_size), where
+ # it should be O(batch_size * input_size).
+ flat_input_ids = array_ops.reshape(input_ids, [-1])
+ one_hots = array_ops.one_hot(flat_input_ids,
+ self._vocab_size) # [?, vocab_size]
+
+ # Take average across examples. Note that, because all entries have
+ # magnitude zero or one, there's no need to square the entries.
+ #
+ # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
+ # within an example such as average.
+ #
+ # TODO(b/72714822): Support for partitioned embeddings.
+ new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+ return new_cov
+
+
class FullyConnectedDiagonalFactor(DiagonalFactor):
r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
@@ -623,8 +871,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
@property
def _cov_shape(self):
- return [self._inputs.shape[1] + self._has_bias,
- self._outputs_grads[0].shape[1]]
+ input_size = self._inputs.shape[1] + self._has_bias
+ output_size = self._outputs_grads[0].shape[1]
+ return [input_size, output_size]
@property
def _num_sources(self):
@@ -717,10 +966,11 @@ class ConvDiagonalFactor(DiagonalFactor):
# TODO(b/64144716): there is potential here for a big savings in terms
# of memory use.
- patches = extract_image_patches(
+ patches = array_ops.extract_image_patches(
self._inputs,
- ksizes=[filter_height, filter_width],
- strides=self._strides[1:-1],
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
padding=self._padding)
if self._has_bias:
@@ -864,10 +1114,11 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
- patches = extract_image_patches(
+ patches = array_ops.extract_image_patches(
self._inputs,
- ksizes=[filter_height, filter_width],
- strides=self._strides[1:-1],
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
padding=self._padding)
flatten_size = (filter_height * filter_width * in_channels)
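
To make the DiagonalFactor sparse path above concrete: left_multiply_inverse with an IndexedSlices argument only touches the rows named by the indices, via matmul_diag_sparse. An editorial NumPy equivalent (not part of the commit), reusing the embedding factor's covariance from the tests:

import numpy as np

cov = np.array([1., 1., 0., 0., 1.]) / 3.    # diagonal covariance, shape [vocab_size]
damping = 0.1
indices = np.array([1, 3, 4])                # rows present in the sparse vector
values = np.ones((3, 1))                     # [len(indices), embedding_size]

inverse_diag = 1. / (cov + damping)          # damped inverse of the diagonal
# matmul_diag_sparse: gather the diagonal entries at the sparse indices and
# scale the corresponding rows of values.
result_values = inverse_diag[indices][:, None] * values
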
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
index ad93919149..2d8e378a93 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
@@ -24,26 +24,15 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
- "inverse_initializer",
- "covariance_initializer",
- "diagonal_covariance_initializer",
- "scope_string_from_params",
- "scope_string_from_name",
- "scalar_or_tensor_to_string",
- "FisherFactor",
- "InverseProvidingFactor",
- "FullFactor",
- "DiagonalFactor",
- "NaiveDiagonalFactor",
- "FullyConnectedDiagonalFactor",
- "FullyConnectedKroneckerFactor",
- "ConvInputKroneckerFactor",
- "ConvOutputKroneckerFactor",
- "ConvDiagonalFactor",
- "set_global_constants",
- "maybe_colocate_with",
- "compute_cov",
- "append_homog"
+ "inverse_initializer", "covariance_initializer",
+ "diagonal_covariance_initializer", "scope_string_from_params",
+ "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
+ "InverseProvidingFactor", "FullFactor", "DiagonalFactor",
+ "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
+ "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
+ "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
+ "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
+ "compute_cov", "append_homog"
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 8d450f04f3..ce9005b9ce 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -143,6 +143,7 @@ class LayerCollection(object):
self._loss_dict = {} # {str: LossFunction}
self._subgraph = None
self._default_generic_approximation = APPROX_FULL_NAME
+ self._default_embedding_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
@@ -179,6 +180,17 @@ class LayerCollection(object):
return self._linked_parameters
@property
+ def default_embedding_approximation(self):
+ return self._default_embedding_approximation
+
+ def set_default_embedding_approximation(self, value):
+ if value != APPROX_KRONECKER_NAME:
+ raise ValueError(
+ "{} is not a valid approximation for embedding variables.".format(
+ value))
+ self._default_embedding_approximation = value
+
+ @property
def default_generic_approximation(self):
return self._default_generic_approximation
@@ -417,6 +429,46 @@ class LayerCollection(object):
else:
return None
+ def register_embedding(self,
+ params,
+ inputs,
+ outputs,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+    """Registers an embedding layer.
+
+ Args:
+ params: Embedding matrix of shape [vocab_size, embedding_size].
+ inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
+ into embedding matrix.
+ outputs: Tensor of shape [batch_size, output_size]. Outputs
+ produced by layer.
+ approx: str. Must be "kron".
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ if approx is None:
+ approx = self._get_linked_approx(params)
+ if approx is None:
+ approx = self.default_embedding_approximation
+
+ if approx != APPROX_KRONECKER_NAME:
+ raise ValueError("Bad value {} for approx.".format(approx))
+
+ if isinstance(params, (tuple, list)):
+ raise ValueError("Bias not supported.")
+
+ vocab_size = int(params.shape[0])
+ block = self.register_block(
+ params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
+ block.register_additional_minibatch(inputs, outputs)
+
def register_fully_connected(self,
params,
inputs,
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index e89508fa46..f5bd97cb4e 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -144,7 +144,9 @@ def layer_params_to_mat2d(vector):
[-1, w_part.shape.as_list()[-1]])
return array_ops.concat(
(w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
- else:
+ elif isinstance(vector, ops.IndexedSlices):
+ return vector
+ else: # Tensor or Tensor-like.
return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
@@ -163,6 +165,11 @@ def mat2d_to_layer_params(vector_template, mat2d):
if isinstance(vector_template, (tuple, list)):
w_part, b_part = mat2d[:-1], mat2d[-1]
return array_ops.reshape(w_part, vector_template[0].shape), b_part
+ elif isinstance(vector_template, ops.IndexedSlices):
+ if not isinstance(mat2d, ops.IndexedSlices):
+ raise TypeError(
+ "If vector_template is an IndexedSlices, so should mat2d.")
+ return mat2d
else:
return array_ops.reshape(mat2d, vector_template.shape)
@@ -420,5 +427,57 @@ def batch_execute(global_step, thunks, batch_size, name=None):
return result
+def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
+ """Computes matmul(A, B) where A is sparse, B is dense.
+
+ Args:
+ A: tf.IndexedSlices with dense shape [m, n].
+ B: tf.Tensor with shape [n, k].
+ name: str. Name of op.
+
+ Returns:
+ tf.IndexedSlices resulting from matmul(A, B).
+
+ Raises:
+ ValueError: If A doesn't represent a matrix.
+ ValueError: If B is not rank-2.
+ """
+ with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
+ if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
+ raise ValueError("A must represent a matrix. Found: %s." % A)
+ if B.shape.ndims != 2:
+ raise ValueError("B must be a matrix.")
+ new_values = math_ops.matmul(A.values, B)
+ return ops.IndexedSlices(
+ new_values,
+ A.indices,
+ dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
+
+
+def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
+ """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
+
+ Args:
+    A_diag: Tensor of shape [m]; the diagonal entries of matrix A ([m, m]).
+ B: tf.IndexedSlices. Represents matrix of shape [m, n].
+ name: str. Name of op.
+
+ Returns:
+ tf.IndexedSlices resulting from matmul(A, B).
+
+ Raises:
+ ValueError: If A_diag is not rank-1.
+ ValueError: If B doesn't represent a matrix.
+ """
+ with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
+ A_diag = ops.convert_to_tensor(A_diag)
+ if A_diag.shape.ndims != 1:
+ raise ValueError("A_diag must be a rank-1 Tensor.")
+ if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
+ raise ValueError("B must represent a matrix. Found: %s." % B)
+ a = array_ops.gather(A_diag, B.indices)
+ a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
+ return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
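
A small NumPy check of what matmul_sparse_dense computes (editorial sketch, not part of the commit): the result keeps A's row indices, and its values are simply A.values times B, which matches the dense product restricted to those rows.

import numpy as np

m, n, k = 5, 3, 2
indices = np.array([1, 4])                         # rows of A that are present
values = np.arange(6, dtype=float).reshape(2, 3)   # A.values, shape [len(indices), n]
B = np.ones((n, k))

A_dense = np.zeros((m, n))
A_dense[indices] = values                          # materialize A for reference
dense_result = A_dense.dot(B)

new_values = values.dot(B)                         # the sparse path's values
assert np.allclose(dense_result[indices], new_values)
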
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index fe8e39c212..8e424a7946 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -40,6 +40,8 @@ _allowed_symbols = [
"fwd_gradients",
"ensure_sequence",
"batch_execute",
+ "matmul_sparse_dense",
+ "matmul_diag_sparse",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)