diff options
6 files changed, 197 insertions, 20 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index f29b17169b..8b82f6e314 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -40,6 +40,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:special_math_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 93235bca53..3bae45b324 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -27,9 +27,10 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -# Damping scale for blocks corresponding to convolutional layers, where the -# damping scale is adjusted according to -# damping /= num_locations ** NORMALIZE_DAMPING_POWER +# For blocks corresponding to convolutional layers, or any type of block where +# the parameters can be thought of as being replicated in time or space, +# we want to adjust the scale of the damping by +# damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 @@ -227,6 +228,70 @@ class FullyConnectedDiagonalFB(FisherBlock): return self._outputs +class ConvDiagonalFB(FisherBlock): + """FisherBlock for convolutional layers using a diagonal approx. + + Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" estimator. + """ + # TODO(jamesmartens): add units tests for this class + + def __init__(self, layer_collection, params, inputs, outputs, strides, + padding): + """Creates a ConvDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + inputs: A Tensor of shape [batch_size, height, width, in_channels]. + Input activations to this layer. + outputs: A Tensor of shape [batch_size, height, width, out_channels]. + Output pre-activations from this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + """ + self._inputs = inputs + self._outputs = outputs + self._strides = strides + self._padding = padding + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + input_shape = tuple(inputs.shape.as_list()) + self._num_locations = (input_shape[1] * input_shape[2] + // (strides[1] * strides[2])) + + super(ConvDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + if NORMALIZE_DAMPING_POWER: + damping /= self._num_locations ** NORMALIZE_DAMPING_POWER + self._damping = damping + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvDiagonalFactor, + (self._inputs, grads_list, self._filter_shape, self._strides, + self._padding, self._has_bias)) + + def multiply_inverse(self, vector): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply(self, vector): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def tensors_to_compute_grads(self): + return self._outputs + + class KroneckerProductFB(FisherBlock): """A base class for FisherBlocks with separate input and output factors. @@ -344,11 +409,16 @@ class ConvKFCBasicFB(KroneckerProductFB): Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. - inputs: The Tensor of input activatoins to this layer. - outputs: The Tensor of output pre-activations from this layer. - strides: The stride size in this layer (1-D of length 4) - padding: The padding in this layer (1-D of length 4) + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + inputs: A Tensor of shape [batch_size, height, width, in_channels]. + Input activations to this layer. + outputs: A Tensor of shape [batch_size, height, width, out_channels]. + Output pre-activations from this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). """ self._inputs = inputs self._outputs = outputs @@ -360,7 +430,7 @@ class ConvKFCBasicFB(KroneckerProductFB): self._filter_shape = tuple(fltr.shape.as_list()) input_shape = tuple(inputs.shape.as_list()) - self._num_locations = (input_shape[1] * input_shape[2] / + self._num_locations = (input_shape[1] * input_shape[2] // (strides[1] * strides[2])) super(ConvKFCBasicFB, self).__init__(layer_collection) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index 4937dd07db..c6cc169b37 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -31,6 +31,7 @@ _allowed_symbols = [ 'KroneckerProductFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', + 'ConvDiagonalFB' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index a776ec0afa..3d14cf1ead 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages @@ -88,18 +89,19 @@ def _compute_cov(tensor, normalizer=None): def _append_homog(tensor): - """Appends a homogeneous coordinate to the row vectors of a 2D Tensor. + """Appends a homogeneous coordinate to the last dimension of a Tensor. Args: - tensor: A 2D Tensor. + tensor: A Tensor. Returns: A Tensor identical to the input but one larger in the last dimension. The new entries are filled with ones. """ - size = array_ops.shape(tensor)[0] - ones = array_ops.ones((size, 1), dtype=tensor.dtype) - return array_ops.concat(values=[tensor, ones], axis=1) + rank = len(tensor.shape.as_list()) + shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) + ones = array_ops.ones(shape, dtype=tensor.dtype) + return array_ops.concat([tensor, ones], axis=rank-1) def scope_string_from_params(params): @@ -162,7 +164,7 @@ class FisherFactor(object): representations. Subclasses must implement the _compute_new_cov method, and the _var_scope - and_cov_shape properties. + and _cov_shape properties. """ def __init__(self): @@ -174,10 +176,19 @@ class FisherFactor(object): @abc.abstractproperty def _cov_shape(self): + """The shape of the cov matrix.""" pass @abc.abstractproperty def _num_sources(self): + """The number of things to sum over when computing cov. + + The default make_covariance_update_op function will call _compute_new_cov + with indices ranging from 0 to _num_sources-1. The typical situation is + where the factor wants to sum the statistics it computes over multiple + backpropped "gradients" (typically passed in via "tensors" or + "outputs_grads" arguments). + """ pass @property @@ -409,6 +420,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): self._orig_tensors_name = scope_string_from_params((inputs,) + tuple(outputs_grads)) + # Note that we precompute the required operations on the inputs since the + # inputs don't change with the 'idx' argument to _compute_new_cov. Only + # the target entry of _outputs_grads changes with idx. if has_bias: inputs = _append_homog(inputs) self._squared_inputs = math_ops.square(inputs) @@ -428,7 +442,10 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): return len(self._outputs_grads) def _compute_new_cov(self, idx=0): - # the magic formula: + # The well-known special formula that uses the fact that the entry-wise + # square of an outer product is the outer-product of the entry-wise squares. + # The gradient is the outer product of the input and the output gradients, + # so we just square both and then take their outer-product. new_cov = math_ops.matmul( self._squared_inputs, math_ops.square(self._outputs_grads[idx]), @@ -437,6 +454,86 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): return new_cov +class ConvDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" + + # TODO(jamesmartens): add units tests for this class + + def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, + has_bias=False): + """Creates a ConvDiagonalFactor object. + + Args: + inputs: Tensor of shape [batch_size, height, width, in_channels]. + Input activations to this layer. + outputs_grads: Tensor of shape [batch_size, height, width, out_channels]. + Per-example gradients to the loss with respect to the layer's output + preactivations. + filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, + out_channels). Represents shape of kernel used in this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + has_bias: Python bool. If True, the layer is assumed to have a bias + parameter in addition to its filter parameter. + """ + self._filter_shape = filter_shape + self._has_bias = has_bias + self._outputs_grads = outputs_grads + + self._orig_tensors_name = scope_string_from_name((inputs,) + + tuple(outputs_grads)) + + # Note that we precompute the required operations on the inputs since the + # inputs don't change with the 'idx' argument to _compute_new_cov. Only + # the target entry of _outputs_grads changes with idx. + filter_height, filter_width, _, _ = self._filter_shape + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=strides, + rates=[1, 1, 1, 1], + padding=padding) + + if has_bias: + patches = _append_homog(patches) + + self._patches = patches + + super(ConvDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convdiag/" + self._orig_tensors_name + + @property + def _cov_shape(self): + filter_height, filter_width, in_channels, out_channels = self._filter_shape + return [filter_height * filter_width * in_channels + self._has_bias, + out_channels] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + def _compute_new_cov(self, idx=0): + outputs_grad = self._outputs_grads[idx] + batch_size = array_ops.shape(self._patches)[0] + + new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _convdiag_sum_of_squares(self, patches, outputs_grad): + # This computes the sum of the squares of the per-training-case "gradients". + # It does this simply by computing a giant tensor containing all of these + # them, doing an entry-wise square, and them summing along the batch + # dimension. + case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, + outputs_grad) + return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + + class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index 8d9ba54e6e..49a07b1598 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -39,6 +39,7 @@ _allowed_symbols = [ "FullyConnectedKroneckerFactor", "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index e5de2ca17c..1b77f5d3ba 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -247,10 +247,17 @@ class LayerCollection(object): else: raise ValueError("Bad value {} for approx.".format(approx)) - def register_conv2d(self, params, strides, padding, inputs, outputs): - self.register_block(params, - fb.ConvKFCBasicFB(self, params, inputs, outputs, - strides, padding)) + def register_conv2d(self, params, strides, padding, inputs, outputs, + approx=APPROX_KRONECKER_NAME): + + if approx == APPROX_KRONECKER_NAME: + self.register_block(params, + fb.ConvKFCBasicFB(self, params, inputs, outputs, + strides, padding)) + elif approx == APPROX_DIAGONAL_NAME: + self.register_block(params, + fb.ConvDiagonalFB(self, params, inputs, outputs, + strides, padding)) def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME): params = params if isinstance(params, (tuple, list)) else (params,) |