Diffstat (limited to 'tensorflow/contrib/kfac/python/ops/fisher_blocks.py')
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 1752
1 file changed, 0 insertions, 1752 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py deleted file mode 100644 index 9fa6eb7dcd..0000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ /dev/null @@ -1,1752 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherBlock definitions. - -This library contains classes for estimating blocks in a model's Fisher -Information matrix. Suppose one has a model that parameterizes a posterior -distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its -Fisher Information matrix is given by, - - $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ - -where, - - $$v(x, y, params) = (d / d params) log p(y | x, params)$$ - -and the expectation is taken with respect to the data's distribution for 'x' and -the model's posterior distribution for 'y', - - x ~ p(x) - y ~ p(y | x, params) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import enum # pylint: disable=g-bad-import-order - -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import fisher_factors -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.util import nest - -# For blocks corresponding to convolutional layers, or any type of block where -# the parameters can be thought of as being replicated in time or space, -# we want to adjust the scale of the damping by -# damping /= num_replications ** NORMALIZE_DAMPING_POWER -NORMALIZE_DAMPING_POWER = 1.0 - -# Methods for adjusting damping for FisherBlocks. See -# compute_pi_adjusted_damping() for details. -PI_OFF_NAME = "off" -PI_TRACENORM_NAME = "tracenorm" -PI_TYPE = PI_TRACENORM_NAME - - -def set_global_constants(normalize_damping_power=None, pi_type=None): - """Sets various global constants used by the classes in this module.""" - global NORMALIZE_DAMPING_POWER - global PI_TYPE - - if normalize_damping_power is not None: - NORMALIZE_DAMPING_POWER = normalize_damping_power - - if pi_type is not None: - PI_TYPE = pi_type - - -def normalize_damping(damping, num_replications): - """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" - if NORMALIZE_DAMPING_POWER: - return damping / (num_replications ** NORMALIZE_DAMPING_POWER) - return damping - - -def compute_pi_tracenorm(left_cov, right_cov): - r"""Computes the scalar constant pi for Tikhonov regularization/damping. - - $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_cov: A LinearOperator object. The left Kronecker factor "covariance". - right_cov: A LinearOperator object. 
The right Kronecker factor "covariance". - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = left_cov.trace() * int(right_cov.domain_dimension) - right_norm = right_cov.trace() * int(left_cov.domain_dimension) - return math_ops.sqrt(left_norm / right_norm) - - -def compute_pi_adjusted_damping(left_cov, right_cov, damping): - - if PI_TYPE == PI_TRACENORM_NAME: - pi = compute_pi_tracenorm(left_cov, right_cov) - return (damping * pi, damping / pi) - - elif PI_TYPE == PI_OFF_NAME: - return (damping, damping) - - -class PackagedFunc(object): - """A Python thunk with a stable ID. - - Enables stable names for lambdas. - """ - - def __init__(self, func, func_id): - """Initializes PackagedFunc. - - Args: - func: a zero-arg Python function. - func_id: a hashable, function that produces a hashable, or a list/tuple - thereof. - """ - self._func = func - func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) - self._func_id = func_id - - def __call__(self): - return self._func() - - @property - def func_id(self): - """A hashable identifier for this function.""" - return tuple(elt() if callable(elt) else elt for elt in self._func_id) - - -def _package_func(func, func_id): - return PackagedFunc(func, func_id) - - -@six.add_metaclass(abc.ABCMeta) -class FisherBlock(object): - """Abstract base class for objects modeling approximate Fisher matrix blocks. - - Subclasses must implement register_matpower, multiply_matpower, - instantiate_factors, tensors_to_compute_grads, and num_registered_towers - methods. - """ - - def __init__(self, layer_collection): - self._layer_collection = layer_collection - - @abc.abstractmethod - def instantiate_factors(self, grads_list, damping): - """Creates and registers the component factors of this Fisher block. - - Args: - grads_list: A list gradients (each a Tensor or tuple of Tensors) with - respect to the tensors returned by tensors_to_compute_grads() that - are to be used to estimate the block. - damping: The damping factor (float or Tensor). - """ - pass - - @abc.abstractmethod - def register_matpower(self, exp): - """Registers a matrix power to be computed by the block. - - Args: - exp: A float representing the power to raise the block by. - """ - pass - - @abc.abstractmethod - def register_cholesky(self): - """Registers a Cholesky factor to be computed by the block.""" - pass - - @abc.abstractmethod - def register_cholesky_inverse(self): - """Registers an inverse Cholesky factor to be computed by the block.""" - pass - - def register_inverse(self): - """Registers a matrix inverse to be computed by the block.""" - self.register_matpower(-1) - - @abc.abstractmethod - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - exp: A float representing the power to raise the block by before - multiplying it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - pass - - def multiply_inverse(self, vector): - """Multiplies the vector by the (damped) inverse of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) inverse of the block. 
- """ - return self.multiply_matpower(vector, -1) - - def multiply(self, vector): - """Multiplies the vector by the (damped) block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) block. - """ - return self.multiply_matpower(vector, 1) - - @abc.abstractmethod - def multiply_cholesky(self, vector, transpose=False): - """Multiplies the vector by the (damped) Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. If true the Cholesky factor is transposed before - multiplying the vector. (Default: False) - - Returns: - The vector left-multiplied by the (damped) Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def multiply_cholesky_inverse(self, vector, transpose=False): - """Multiplies vector by the (damped) inverse Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. If true the Cholesky factor inverse is transposed - before multiplying the vector. (Default: False) - Returns: - Vector left-multiplied by (damped) inverse Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def tensors_to_compute_grads(self): - """Returns the Tensor(s) with respect to which this FisherBlock needs grads. - """ - pass - - @abc.abstractproperty - def num_registered_towers(self): - """Number of towers registered for this FisherBlock. - - Typically equal to the number of towers in a multi-tower setup. - """ - pass - - -class FullFB(FisherBlock): - """FisherBlock using a full matrix estimate (no approximations). - - FullFB uses a full matrix estimate (no approximations), and should only ever - be used for very low dimensional parameters. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a FullFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). 
- """ - self._batch_sizes = [] - self._params = params - - super(FullFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullFactor, (grads_list, self._batch_size)) - - def register_matpower(self, exp): - self._factor.register_matpower(exp, self._damping_func) - - def register_cholesky(self): - self._factor.register_cholesky(self._damping_func) - - def register_cholesky_inverse(self): - self._factor.register_cholesky_inverse(self._damping_func) - - def _multiply_matrix(self, matrix, vector, transpose=False): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat, adjoint=transpose) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov_as_linear_operator().to_dense() - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -@six.add_metaclass(abc.ABCMeta) -class DiagonalFB(FisherBlock): - """A base class for FisherBlocks that use diagonal approximations.""" - - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def register_cholesky(self): - # Not needed for this. Cholesky's are computed on demand in the - # diagonal case - pass - - def register_cholesky_inverse(self): - # Not needed for this. Cholesky inverses's are computed on demand in the - # diagonal case - pass - - def _multiply_matrix(self, matrix, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def full_fisher_block(self): - return self._factor.get_cov_as_linear_operator().to_dense() - - -class NaiveDiagonalFB(DiagonalFB): - """FisherBlock using a diagonal matrix approximation. - - This type of approximation is generically applicable but quite primitive. 
- - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a NaiveDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). - """ - self._params = params - self._batch_sizes = [] - - super(NaiveDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -class InputOutputMultiTower(object): - """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" - - def __init__(self, *args, **kwargs): - self.__inputs = [] - self.__outputs = [] - super(InputOutputMultiTower, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process data into the format used by the factors. - - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - The initial format of self._inputs is expected to be a list of Tensors - over towers. Similarly grads_lists is expected to be a list over sources - of such lists. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single - tensor (represented as a PartitionedTensor object) equal to the - concatenation (across towers) of all of the elements of self._inputs. And - similarly grads_list is formatted into a tuple (over sources) of such - tensors (also represented as PartitionedTensors). - - If TOWER_STRATEGY is "separate", formatting of inputs and grads_list - remains unchanged from the initial format (although possibly converting - from lists into tuples). - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". - """ - inputs = self._inputs - # inputs is a list over towers of Tensors - # grads_list is a list of list with the first index being sources and the - # second being towers. - if fisher_factors.TOWER_STRATEGY == "concat": - # Merge towers together into a PartitionedTensor. 
We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - # Do the same for grads_list but preserve leading sources dimension - grads_list = tuple((utils.PartitionedTensor(grads),) - for grads in grads_list) - elif fisher_factors.TOWER_STRATEGY == "separate": - inputs = tuple(inputs) - grads_list = tuple(grads_list) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - - return inputs, grads_list - - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return tuple(self._outputs) - - def register_additional_tower(self, inputs, outputs): - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_towers(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - - @property - def _inputs(self): - return self.__inputs - - @property - def _outputs(self): - return self.__outputs - - -class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): - """FisherBlock for fully-connected (dense) layers using a diagonal approx. - - Estimates the Fisher Information matrix's diagonal entries for a fully - connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of - squares" estimator. - - Let 'params' be a vector parameterizing a model and 'i' an arbitrary index - into it. We are interested in Fisher(params)[i, i]. This is, - - $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ]$$ - - Consider a fully connected layer in this model with (unshared) weight matrix - 'w'. For an example 'x' that produces layer inputs 'a' and output - preactivations 's', - - $$v(x, y, w) = vec( a (d loss / d s)^T )$$ - - This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding - to the layer's parameters 'w'. - """ - - def __init__(self, layer_collection, has_bias=False): - """Creates a FullyConnectedDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the layer has an additive bias parameter. - (Default: False) - """ - self._has_bias = has_bias - - super(FullyConnectedDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedDiagonalFactor, - (inputs, grads_list, self._has_bias)) - - self._damping_func = _package_func(lambda: damping, (damping,)) - - -class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): - """FisherBlock for 2-D convolutional layers using a diagonal approx. - - Estimates the Fisher Information matrix's diagonal entries for a convolutional - layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" - estimator. - - Let 'params' be a vector parameterizing a model and 'i' an arbitrary index - into it. We are interested in Fisher(params)[i, i]. This is, - - $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ]$$ - - Consider a convolutional layer in this model with (unshared) filter matrix - 'w'.
For an example image 'x' that produces layer inputs 'a' and output - preactivations 's', - - $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ - - where 'loc' is a single (x, y) location in an image. - - This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding - to the layer's parameters 'w'. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - data_format=None, - dilations=None): - """Creates a ConvDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [kernel_height, kernel_width, - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (e.g. "SAME"). - data_format: str or None. Format of input data. - dilations: List of 4 ints or None. Rate for dilation along all dimensions. - - Raises: - ValueError: if strides is not length-4. - ValueError: if dilations is not length-4. - ValueError: if channel is not last dimension. - """ - if len(strides) != 4: - raise ValueError("strides must contain 4 numbers.") - - if dilations is None: - dilations = [1, 1, 1, 1] - - if len(dilations) != 4: - raise ValueError("dilations must contain 4 numbers.") - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - self._strides = maybe_tuple(strides) - self._padding = padding - self._data_format = data_format - self._dilations = maybe_tuple(dilations) - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - if len(self._filter_shape) != 4: - raise ValueError( - "Convolution filter must be of shape" - " [filter_height, filter_width, in_channels, out_channels].") - - super(ConvDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvDiagonalFactor, - (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._data_format, self._dilations, self._has_bias)) - - def damping_func(): - return self._num_locations * normalize_damping(damping, - self._num_locations) - - damping_id = (self._num_locations, "mult", "normalize_damping", damping, - self._num_locations) - self._damping_func = _package_func(damping_func, damping_id) - - -class KroneckerProductFB(FisherBlock): - """A base class for blocks with separate input and output Kronecker factors. - - The Fisher block is approximated as a Kronecker product of the input and - output factors. 
- """ - - def _setup_damping(self, damping, normalization=None): - """Makes functions that compute the damping values for both factors.""" - def compute_damping(): - if normalization is not None: - maybe_normalized_damping = normalize_damping(damping, normalization) - else: - maybe_normalized_damping = damping - - return compute_pi_adjusted_damping( - self._input_factor.get_cov_as_linear_operator(), - self._output_factor.get_cov_as_linear_operator(), - maybe_normalized_damping**0.5) - - if normalization is not None: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - "normalize_damping", damping, normalization, "power", 0.5) - else: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - damping, "power", 0.5) - - self._input_damping_func = _package_func(lambda: compute_damping()[0], - damping_id + ("ref", 0)) - self._output_damping_func = _package_func(lambda: compute_damping()[1], - damping_id + ("ref", 1)) - - def register_matpower(self, exp): - self._input_factor.register_matpower(exp, self._input_damping_func) - self._output_factor.register_matpower(exp, self._output_damping_func) - - def register_cholesky(self): - self._input_factor.register_cholesky(self._input_damping_func) - self._output_factor.register_cholesky(self._output_damping_func) - - def register_cholesky_inverse(self): - self._input_factor.register_cholesky_inverse(self._input_damping_func) - self._output_factor.register_cholesky_inverse(self._output_damping_func) - - @property - def _renorm_coeff(self): - """Kronecker factor multiplier coefficient. - - If this FisherBlock is represented as 'FB = c * kron(left, right)', then - this is 'c'. - - Returns: - 0-D Tensor. 
- """ - return 1.0 - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = right_factor.matmul_right(reshaped_vector, - adjoint=transpose_right) - reshaped_out = left_factor.matmul(reshaped_out, - adjoint=transpose_left) - if extra_scale != 1.0: - reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply_matpower(self, vector, exp): - left_factor = self._input_factor.get_matpower( - exp, self._input_damping_func) - right_factor = self._output_factor.get_matpower( - exp, self._output_damping_func) - extra_scale = float(self._renorm_coeff)**exp - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale) - - def multiply_cholesky(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky(self._input_damping_func) - right_factor = self._output_factor.get_cholesky(self._output_damping_func) - extra_scale = float(self._renorm_coeff)**0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky_inverse( - self._input_damping_func) - right_factor = self._output_factor.get_cholesky_inverse( - self._output_damping_func) - extra_scale = float(self._renorm_coeff)**-0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block. - - Used for testing purposes. (In general, the result may be very large.) - - Returns: - The full Fisher block. - """ - left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() - right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() - return self._renorm_coeff * utils.kronecker_product(left_factor, - right_factor) - - -class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): - """K-FAC FisherBlock for embedding layers. - - This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its - input factor is approximated by a diagonal matrix. In the case that each - example references exactly one embedding, this approximation is exact. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size): - """Creates a EmbeddingKFACFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of Tensors. grads_list[i][j] is the - gradient of the loss with respect to 'outputs' from source 'i' and - tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. 
- """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.EmbeddingInputKroneckerFactor, - (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) - self._setup_damping(damping) - - -class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): - """K-FAC FisherBlock for fully-connected (dense) layers. - - This uses the Kronecker-factorized approximation from the original - K-FAC paper (https://arxiv.org/abs/1503.05671) - """ - - def __init__(self, layer_collection, has_bias=False): - """Creates a FullyConnectedKFACBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the component Kronecker factors have an additive bias. - (Default: False) - """ - self._has_bias = has_bias - - super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of Tensors. grads_list[i][j] is the - gradient of the loss with respect to 'outputs' from source 'i' and - tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, - ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, - (grads_list,)) - self._setup_damping(damping) - - -class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): - r"""FisherBlock for convolutional layers using the basic KFC approx. - - Estimates the Fisher Information matrix's blog for a convolutional - layer. - - Consider a convolutional layer in this model with (unshared) filter matrix - 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', - this FisherBlock estimates, - - $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], - E[flat(ds) flat(ds)^T])$$ - - where - - $$ds = (d / ds) log p(y | x, w)$$ - #locations = number of (x, y) locations where 'w' is applied. - - where the expectation is taken over all examples and locations and flat() - concatenates an array's leading dimensions. - - See equation 23 in https://arxiv.org/abs/1602.01407 for details. - """ - - def __init__(self, - layer_collection, - params, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None): - """Creates a ConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [..spatial_filter_shape.., - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - padding: str. Padding method. - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. 
- dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization=self._num_locations) - - @property - def _renorm_coeff(self): - return self._num_locations - - -class DepthwiseConvDiagonalFB(ConvDiagonalFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvDiagonalFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvDiagonalFB, self).__init__( - layer_collection=layer_collection, - params=params, - strides=strides, - padding=padding, - dilations=rate, - data_format=data_format) - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). 
- filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_matrix(self, matrix, vector): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvKFCBasicFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvKFCBasicFB, self).__init__( - layer_collection=layer_collection, - params=params, - padding=padding, - strides=strides, - dilation_rate=rate, - data_format=data_format, - extract_patches_fn="extract_image_patches") - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). - filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( - left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, - transpose_left=transpose_left, transpose_right=transpose_right) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with conv2d. - - Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's - compatible with tf.nn.conv2d(). - - Args: - filter: Tensor of shape [height, width, in_channels, channel_multiplier]. - name: None or str. Name of Op. - - Returns: - Tensor of shape [height, width, in_channels, out_channels]. 
- - """ - with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, channel_multiplier = ( - filter.shape.as_list()) - - results = [] - for i in range(in_channels): - # Slice out one in_channel's filter. Insert zeros around it to force it - # to affect that channel and that channel alone. - elements = [] - if i > 0: - elements.append( - array_ops.zeros( - [filter_height, filter_width, i, channel_multiplier])) - elements.append(filter[:, :, i:(i + 1), :]) - if i + 1 < in_channels: - elements.append( - array_ops.zeros([ - filter_height, filter_width, in_channels - (i + 1), - channel_multiplier - ])) - - # Concat along in_channel. - results.append( - array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) - - # Concat along out_channel. - return array_ops.concat(results, axis=-1, name="out_channel") - - -def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with depthwise_conv2d. - - Transforms a filter for use with tf.nn.conv2d() to one that's - compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along - the diagonal. - - Args: - filter: Tensor of shape [height, width, in_channels, out_channels]. - name: None or str. Name of Op. - - Returns: - Tensor of shape, - [height, width, in_channels, channel_multiplier] - - Raises: - ValueError: if out_channels is not evenly divisible by in_channels. - """ - with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, out_channels = ( - filter.shape.as_list()) - - if out_channels % in_channels != 0: - raise ValueError("out_channels must be evenly divisible by in_channels.") - channel_multiplier = out_channels // in_channels - - results = [] - filter = array_ops.reshape(filter, [ - filter_height, filter_width, in_channels, in_channels, - channel_multiplier - ]) - for i in range(in_channels): - # Slice out output corresponding to the correct filter. - filter_slice = array_ops.reshape( - filter[:, :, i, i, :], - [filter_height, filter_width, 1, channel_multiplier]) - results.append(filter_slice) - - # Concat along out_channel. - return array_ops.concat(results, axis=-2, name="in_channels") - - -def maybe_tuple(obj): - if not isinstance(obj, list): - return obj - return tuple(obj) - - -def num_conv_locations(input_shape, strides): - """Returns the number of spatial locations a 2D Conv kernel is applied to. - - Args: - input_shape: List of ints representing shape of inputs to - tf.nn.convolution(). - strides: List of ints representing strides along spatial dimensions as - passed in to tf.nn.convolution(). - - Returns: - A scalar |T| denoting the number of spatial locations for the Conv layer. - """ - spatial_input_locations = np.prod(input_shape[1:-1]) - - if strides is None: - spatial_strides_divisor = 1 - else: - spatial_strides_divisor = np.prod(strides) - - return spatial_input_locations // spatial_strides_divisor - - -class InputOutputMultiTowerMultiUse(InputOutputMultiTower): - """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" - - def __init__(self, num_uses=None, *args, **kwargs): - self._num_uses = num_uses - super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process temporal/multi-use data into the format used by the factors. 
- - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - It accepts the data in one of two initial formats. The first possible - format is where self._inputs is a list of list of Tensors. The first index - is tower, the second is use/time-step. grads_list, meanwhile, is a list - over sources of such lists of lists. - - The second possible data format is where self._inputs is a Tensor with - uses/times-steps folded into the batch dimension. i.e. it is a Tensor - of shape [num_uses * size_batch, ...] which represents a reshape of a - Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is - a list over sources of such Tensors. - - There are two possible formats which inputs and grads_list are transformed - into. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing - a single tensor (represented as a PartitionedTensor object) with all of - the data from the towers, as well as the uses/time-steps, concatenated - together. In this tensor the leading dimension is the batch and - use/time-step dimensions folded together (with 'use' being the major of - these two, so that the tensors can be thought of as reshapes of ones of - shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a - tuple over sources of such tensors. - - If TOWER_STRATEGY is "separate" the inputs are formatted into lists of - tensors over towers. Each of these tensors has a similar format to - the tensor produced by the "concat" option, except that each contains - only the data from a single tower. grads_list is similarly formatted - into a tuple over sources of such tuples. - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". - ValueError: If the given/initial format of self._inputs and grads_list - isn't recognized, or doesn't agree with self._num_uses. - """ - - inputs = self._inputs - - if isinstance(inputs[0], (list, tuple)): - num_uses = len(inputs[0]) - if self._num_uses is not None and self._num_uses != num_uses: - raise ValueError("num_uses argument doesn't match length of inputs.") - else: - self._num_uses = num_uses - - # Check that all mini-batches/towers have the same number of uses - if not all(len(input_) == num_uses for input_ in inputs): - raise ValueError("Length of inputs argument is inconsistent across " - "towers.") - - if fisher_factors.TOWER_STRATEGY == "concat": - # Reverse the tower and use/time-step indices, so that use is now first, - # and towers is second - inputs = tuple(zip(*inputs)) - - # Flatten the two dimensions - inputs = nest.flatten(inputs) - - # Merge everything together into a PartitionedTensor. We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - - elif fisher_factors.TOWER_STRATEGY == "separate": - # Merge together the uses/time-step dimension into PartitionedTensors, - # but keep the leading dimension (towers) intact for the factors to - # process individually. 
- inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - else: - inputs = tuple(inputs) - - # Now we perform the analogous processing for grads_list - if isinstance(grads_list[0][0], (list, tuple)): - num_uses = len(grads_list[0][0]) - if self._num_uses is not None and self._num_uses != num_uses: - raise ValueError("num_uses argument doesn't match length of outputs, " - "or length of outputs is inconsistent with length of " - "inputs.") - else: - self._num_uses = num_uses - - if not all(len(grad) == num_uses for grads in grads_list - for grad in grads): - raise ValueError("Length of outputs argument is inconsistent across " - "towers.") - - if fisher_factors.TOWER_STRATEGY == "concat": - # Reverse the tower and use/time-step indices, so that use is now first, - # and towers is second - grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) - - # Flatten the two dimensions, leaving the leading dimension (source) - # intact - grads_list = tuple(nest.flatten(grads) for grads in grads_list) - - # Merge inner dimensions together into PartitionedTensors. We package - # them in a singleton tuple since the factors will expect a list over - # towers - grads_list = tuple((utils.PartitionedTensor(grads),) - for grads in grads_list) - - elif fisher_factors.TOWER_STRATEGY == "separate": - # Merge together the uses/time-step dimension into PartitionedTensors, - # but keep the leading dimension (towers) intact for the factors to - # process individually. - grads_list = tuple(tuple(utils.PartitionedTensor(grad) - for grad in grads) - for grads in grads_list) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - else: - grads_list = tuple(tuple(grads) for grads in grads_list) - - if self._num_uses is None: - raise ValueError("You must supply a value for the num_uses argument if " - "the number of uses cannot be inferred from inputs or " - "outputs arguments (e.g. if they are both given in the " - "single Tensor format, instead of as lists of Tensors.") - - return inputs, grads_list - - -class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for fully-connected layers that share parameters. - - This class implements the "independence across time" approximation from the - following paper: - https://openreview.net/pdf?id=HyMTkQZAb - """ - - def __init__(self, layer_collection, has_bias=False, num_uses=None): - """Creates a FullyConnectedMultiIndepFB block. - - Args: - layer_collection: LayerCollection instance. - has_bias: bool. If True, estimates Fisher with respect to a bias - parameter as well as the layer's parameters. - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with uses/time folded into the - batch dimension (instead of uses/time being a list dimension). 
- (Default: None) - """ - self._has_bias = has_bias - - super(FullyConnectedMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. - - Similar to ConvKFCBasicFB except that this version supports multiple - uses/time-steps via a standard independence approximation. Similar to the - "independence across time" used in FullyConnectedMultiIndepFB but generalized - in the obvious way to conv layers. - """ - - def __init__(self, - layer_collection, - params, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None, - num_uses=None): - """Creates a ConvKFCBasicMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [..spatial_filter_shape.., - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - padding: str. Padding method. - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. - dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with uses/time folded into the - batch dimension (instead of uses/time being a list dimension). - (Default: None) - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. 
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization= - (self._num_locations * self._num_uses)) - - @property - def _renorm_coeff(self): - return self._num_locations * self._num_uses - - -class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """K-FAC FisherBlock for embedding layers used multiple times in the graph. - - Similar to EmbeddingKFACFB except that this version supports multiple uses - of the parameter within a single model. These uses could correspond to time - steps in an RNN architecture, but they don't have to. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size, num_uses=None): - """Creates a EmbeddingKFACMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with time folded into the batch - dimension (instead of time being a list dimension). (Default: None) - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of list of Tensors. grads_list[i][j][k] is the - gradient of the loss with respect to 'outputs' from source 'i', - tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape - [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.EmbeddingInputKroneckerFactor, - (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class SeriesFBApproximation(enum.IntEnum): - """See FullyConnectedSeriesFB.__init__ for description and usage.""" - option1 = 1 - option2 = 2 - - -class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for fully-connected layers that share parameters across time. - - This class implements the "Option 1" and "Option 2" approximation from the - following paper: - https://openreview.net/pdf?id=HyMTkQZAb - - See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_matpower here. Note that we are - using pre-computed versions of certain matrix-matrix products to speed - things up. This is explicitly explained wherever it is done. 
- """ - - def __init__(self, - layer_collection, - has_bias=False, - num_uses=None, - option=SeriesFBApproximation.option2): - """Constructs a new `FullyConnectedSeriesFB`. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the layer includes a bias parameter. - num_uses: int or None. Number of time-steps over which the layer - is used. Only required if the data is formatted with time folded into - the batch dimension (instead of time being a list dimension). - (Default: None) - option: A `SeriesFBApproximation` specifying the simplifying assumption - to be used in this block. `option1` approximates the cross-covariance - over time as a symmetric matrix, while `option2` makes - the assumption that training sequences are infinitely long. See section - 3.5 of the paper for more details. - """ - - self._has_bias = has_bias - self._option = option - - super(FullyConnectedSeriesFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - @property - def _num_timesteps(self): - return self._num_uses - - @property - def _renorm_coeff(self): - # This should no longer be used since the multiply_X functions from the base - # class have been overridden - assert False - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - self._input_factor.register_cov_dt1() - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._output_factor.register_cov_dt1() - - self._setup_damping(damping, normalization=self._num_uses) - - def register_matpower(self, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._input_damping_func) - self._output_factor.register_option1quants(self._output_damping_func) - elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._input_damping_func) - self._output_factor.register_option2quants(self._output_damping_func) - else: - raise ValueError( - "Unrecognized FullyConnectedSeriesFB approximation: {}".format( - self._option)) - - def multiply_matpower(self, vector, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - # pylint: disable=invalid-name - - Z = utils.layer_params_to_mat2d(vector) - - # Derivations were done for "batch_dim==1" case so we need to convert to - # that orientation: - Z = array_ops.transpose(Z) - - if self._option == SeriesFBApproximation.option1: - - # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) - L_A, psi_A = self._input_factor.get_option1quants( - self._input_damping_func) - L_G, psi_G = self._output_factor.get_option1quants( - self._output_damping_func) - - def gamma(x): - # We are assuming that each case has the same number of time-steps. - # If this stops being the case one shouldn't simply replace this T - # with its average value. Instead, one needs to go back to the - # definition of the gamma function from the paper. 
- T = self._num_timesteps - return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) - - # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) - # Even though Y is Z-independent we are recomputing it from the psi's - # each since Y depends on both A and G quantities, and it is relatively - # cheap to compute. - Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) - - # \\(Z = L_G^T * Z * L_A\\) - # This is equivalent to the following computation from the original - # pseudo-code: - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - # \\(Z = U_G^T * Z * U_A\\) - Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) - - # \\(Z = Z .* Y\\) - Z *= Y - - # \\(Z = L_G * Z * L_A^T\\) - # This is equivalent to the following computation from the original - # pseudo-code: - # \\(Z = U_G * Z * U_A^T\\) - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) - - elif self._option == SeriesFBApproximation.option2: - - # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), - # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) - P_A, K_A, mu_A = self._input_factor.get_option2quants( - self._input_damping_func) - P_G, K_G, mu_G = self._output_factor.get_option2quants( - self._output_damping_func) - - # Our approach differs superficially from the pseudo-code in the paper - # in order to reduce the total number of matrix-matrix multiplies. - # In particular, the first three computations in the pseudo code are - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) - # \\(Z = E_G^T * Z * E_A\\) - # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that - # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) - # the entire computation can be written as - # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) - # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) - # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) - # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) - # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) - # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) - # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) - # This final expression is computed by the following two lines: - # \\(Z = Z - P_G * Z * P_A^T\\) - Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) - # \\(Z = K_G^T * Z * K_A\\) - Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) - - # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) - # Be careful with the outer product. We don't want to accidentally - # make it an inner-product instead. - tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A - # Prevent some numerical issues by setting any 0.0 eigs to 1.0 - tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) - Z /= tmp - - # We now perform the transpose/reverse version of the operations - # derived above, whose derivation from the original pseudo-code is - # analgous. - # \\(Z = K_G * Z * K_A^T\\) - Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) - - # \\(Z = Z - P_G^T * Z * P_A\\) - Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) - - # \\(Z = normalize (1/E[T]) * Z\\) - # Note that this normalization is done because we compute the statistics - # by averaging, not summing, over time. (And the gradient is presumably - # summed over time, not averaged, and thus their scales are different.) 
- Z /= math_ops.cast(self._num_timesteps, Z.dtype) - - # Convert back to the "batch_dim==0" orientation. - Z = array_ops.transpose(Z) - - return utils.mat2d_to_layer_params(vector, Z) - - # pylint: enable=invalid-name - - def multiply_cholesky(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - - def multiply_cholesky_inverse(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - |
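
The deleted module's docstring contrasts two ways of estimating Fisher entries from per-example gradients: the naive "square the sum" estimator used by FullFB and NaiveDiagonalFB, and the lower-variance "sum of squares" estimator used by FullyConnectedDiagonalFB and ConvDiagonalFB. The following standalone NumPy sketch (illustrative only; not part of the file or this commit, and only roughly what the corresponding factors compute) shows the difference on synthetic gradients:

import numpy as np

rng = np.random.default_rng(0)
batch_size, dim = 32, 5
# v[n] = d/d(params) log p(y_n | x_n, params), one row per example.
v = rng.normal(size=(batch_size, dim))

# Naive "square the sum": square the minibatch-summed gradient and divide
# by the batch size. Applicable to any parameter, but high variance.
g = v.sum(axis=0)
naive_diag = g**2 / batch_size

# "Sum of squares": average the squared per-example gradients, i.e. the
# diagonal of the empirical E[v v^T].
sum_of_squares_diag = (v**2).mean(axis=0)

# Both estimate diag(F); the naive version is noisier because its
# cross-example terms only vanish in expectation.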
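
compute_pi_tracenorm() and compute_pi_adjusted_damping() above split one damping strength across the two Kronecker factors. A minimal NumPy sketch of the same arithmetic (illustrative; dense arrays stand in for the LinearOperator objects the real code uses):

import numpy as np

def compute_pi_tracenorm(left_cov, right_cov):
    # pi = sqrt((trace(A) / dim(A)) / (trace(B) / dim(B))). Multiplying each
    # trace by the *other* factor's dimension gives the same ratio without
    # explicit divisions, as noted in the original comment.
    left_norm = np.trace(left_cov) * right_cov.shape[0]
    right_norm = np.trace(right_cov) * left_cov.shape[0]
    return np.sqrt(left_norm / right_norm)

def compute_pi_adjusted_damping(left_cov, right_cov, damping):
    pi = compute_pi_tracenorm(left_cov, right_cov)
    return damping * pi, damping / pi

# KroneckerProductFB._setup_damping() passes damping**0.5 into this function,
# so that kron(A + pi*sqrt(d)*I, G + (sqrt(d)/pi)*I) ~= kron(A, G) + d*I.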
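
PackagedFunc exists because two syntactically identical lambdas never compare equal in Python, which would defeat the caching done by make_or_get_factor(). A short illustrative usage, assuming the PackagedFunc class exactly as defined above:

damping = 0.01

# Two distinct lambda objects that compute the same damping value.
func_a = PackagedFunc(lambda: damping, (damping,))
func_b = PackagedFunc(lambda: damping, (damping,))

assert func_a is not func_b              # different thunks...
assert func_a.func_id == func_b.func_id  # ...with the same stable identity
assert func_a() == func_b() == 0.01      # and the same value when called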
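
KroneckerProductFB._multiply_factored_matrix() never forms the Kronecker product explicitly: it reshapes the parameter vector into its natural [input_size, output_size] matrix and multiplies by the two factors separately. A standalone NumPy check of the underlying identity (illustrative; it relies on the factors being symmetric covariances and on row-major flattening):

import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 4, 3

A = rng.normal(size=(in_dim, in_dim))
A = A @ A.T    # input factor (symmetric positive semi-definite)
G = rng.normal(size=(out_dim, out_dim))
G = G @ G.T    # output factor (symmetric positive semi-definite)
W = rng.normal(size=(in_dim, out_dim))  # layer parameters as a matrix

# Explicit (in*out) x (in*out) block applied to the flattened parameters...
dense = np.kron(A, G) @ W.ravel()
# ...equals two small matrix products, since for row-major vec() we have
# kron(A, G) vec(W) = vec(A W G^T), and G^T = G here.
factored = (A @ W @ G).ravel()

assert np.allclose(dense, factored)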
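
As a usage example of num_conv_locations() and normalize_damping() above (the shapes and damping value here are made up): a stride-2 kernel over a 32x32 input is applied at 256 spatial locations, and with NORMALIZE_DAMPING_POWER = 1.0 the ConvDiagonalFB damping_func multiplies that count back into the normalized damping:

# num_conv_locations([batch, 32, 32, 3], strides=[1, 2, 2, 1]):
#   (32 * 32) // (1 * 2 * 2 * 1) = 256
num_locations = (32 * 32) // (2 * 2)

damping = 0.01
normalized = damping / num_locations            # normalize_damping(damping, 256)
conv_diag_damping = num_locations * normalized  # == damping when the power is 1.0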
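
depthwise_conv2d_filter_to_conv2d_filter() above embeds a depthwise filter into an equivalent conv2d filter that is zero off the channel diagonal. A compact NumPy sketch of the same embedding (illustrative; the TensorFlow version builds the result from zero blocks and concat rather than direct indexing):

import numpy as np

def depthwise_to_conv2d(depthwise_filter):
    """[h, w, in_c, mult] -> [h, w, in_c, in_c * mult], zero off-diagonal."""
    h, w, in_c, mult = depthwise_filter.shape
    out = np.zeros((h, w, in_c, in_c * mult), dtype=depthwise_filter.dtype)
    for i in range(in_c):
        # Input channel i only affects output channels [i*mult, (i+1)*mult).
        out[:, :, i, i * mult:(i + 1) * mult] = depthwise_filter[:, :, i, :]
    return out

f = np.arange(2 * 2 * 3 * 2, dtype=np.float64).reshape(2, 2, 3, 2)
g = depthwise_to_conv2d(f)
assert g.shape == (2, 2, 3, 6)

# conv2d_filter_to_depthwise_conv2d_filter() inverts this mapping by reading
# back only the diagonal blocks.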
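
The gamma function inside FullyConnectedSeriesFB.multiply_matpower() (option1) has a simple sanity check: with no temporal correlation it must reduce to the 1/T scaling of the independence approximation. A small illustrative check, assuming a fixed number of time-steps T:

# gamma(x) = (1 - x)^2 / (T * (1 - x^2) - 2 * x * (1 - x^T)), applied
# element-wise to products of the input/output eigenvalues psi_A, psi_G.
T = 10

def gamma(x):
    return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

# At x = 0 the denominator is exactly T, so gamma(0) == 1/T: uncorrelated
# time-steps recover the "independence across time" normalization used by
# FullyConnectedMultiIndepFB.
assert abs(gamma(0.0) - 1.0 / T) < 1e-12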