-rw-r--r--  tensorflow/contrib/kfac/python/ops/BUILD                   |   1
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py        |  88
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py    |   1
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py       | 111
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py   |   1
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py     |  15
6 files changed, 197 insertions(+), 20 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index f29b17169b..8b82f6e314 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -40,6 +40,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:special_math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 93235bca53..3bae45b324 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -27,9 +27,10 @@ from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
-# Damping scale for blocks corresponding to convolutional layers, where the
-# damping scale is adjusted according to
-# damping /= num_locations ** NORMALIZE_DAMPING_POWER
+# For blocks corresponding to convolutional layers, or any type of block where
+# the parameters can be thought of as being replicated in time or space,
+# we want to adjust the scale of the damping by
+# damping /= num_replications ** NORMALIZE_DAMPING_POWER
NORMALIZE_DAMPING_POWER = 1.0
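A quick numeric illustration of this rescaling (all values invented):

    NORMALIZE_DAMPING_POWER = 1.0
    damping = 1e-3                # base damping chosen by the optimizer
    num_replications = 28 * 28    # e.g. a 28x28 feature map, stride 1
    damping /= num_replications ** NORMALIZE_DAMPING_POWER
    print(damping)                # 1e-3 / 784, roughly 1.28e-06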
@@ -227,6 +228,70 @@ class FullyConnectedDiagonalFB(FisherBlock):
return self._outputs
+class ConvDiagonalFB(FisherBlock):
+ """FisherBlock for convolutional layers using a diagonal approx.
+
+ Unlike NaiveDiagonalFB, this uses the low-variance "sum of squares" estimator.
+ """
+ # TODO(jamesmartens): add unit tests for this class
+
+ def __init__(self, layer_collection, params, inputs, outputs, strides,
+ padding):
+ """Creates a ConvDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [kernel_height, kernel_width,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+ containing the kernel Tensor and a bias Tensor of shape [out_channels].
+ inputs: A Tensor of shape [batch_size, height, width, in_channels].
+ Input activations to this layer.
+ outputs: A Tensor of shape [batch_size, height, width, out_channels].
+ Output pre-activations from this layer.
+ strides: The stride size in this layer (1-D Tensor of length 4).
+ padding: The padding in this layer (1-D Tensor of length 4).
+ """
+ self._inputs = inputs
+ self._outputs = outputs
+ self._strides = strides
+ self._padding = padding
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ input_shape = tuple(inputs.shape.as_list())
+ self._num_locations = (input_shape[1] * input_shape[2]
+ // (strides[1] * strides[2]))
+
+ super(ConvDiagonalFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ if NORMALIZE_DAMPING_POWER:
+ damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
+ self._damping = damping
+
+ self._factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvDiagonalFactor,
+ (self._inputs, grads_list, self._filter_shape, self._strides,
+ self._padding, self._has_bias))
+
+ def multiply_inverse(self, vector):
+ reshaped_vect = utils.layer_params_to_mat2d(vector)
+ reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping)
+ return utils.mat2d_to_layer_params(vector, reshaped_out)
+
+ def multiply(self, vector):
+ reshaped_vect = utils.layer_params_to_mat2d(vector)
+ reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping)
+ return utils.mat2d_to_layer_params(vector, reshaped_out)
+
+ def tensors_to_compute_grads(self):
+ return self._outputs
+
+
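For intuition, a NumPy-only sketch of the arithmetic behind multiply_inverse
and multiply above (shapes and values invented; the layer_params_to_mat2d
reshaping is elided): since the Fisher is approximated by its diagonal,
applying (F + damping*I)^-1 reduces to elementwise division.

    import numpy as np

    cov = np.array([[0.5, 2.0],
                    [1.0, 4.0]])       # diagonal Fisher entries, matrix layout
    damping = 1e-3
    vector = np.array([[1.0, 1.0],
                       [2.0, 2.0]])

    result = vector / (cov + damping)       # what multiply_inverse computes
    roundtrip = result * (cov + damping)    # what multiply computes
    assert np.allclose(roundtrip, vector)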
class KroneckerProductFB(FisherBlock):
"""A base class for FisherBlocks with separate input and output factors.
@@ -344,11 +409,16 @@ class ConvKFCBasicFB(KroneckerProductFB):
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer.
- inputs: The Tensor of input activatoins to this layer.
- outputs: The Tensor of output pre-activations from this layer.
- strides: The stride size in this layer (1-D of length 4)
- padding: The padding in this layer (1-D of length 4)
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [kernel_height, kernel_width,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+ containing the previous and a Tensor of shape [out_channels].
+ inputs: A Tensor of shape [batch_size, height, width, in_channels].
+ Input activations to this layer.
+ outputs: A Tensor of shape [batch_size, height, width, out_channels].
+ Output pre-activations from this layer.
+ strides: The stride size in this layer (1-D Tensor of length 4).
+ padding: The padding in this layer (1-D Tensor of length 4).
"""
self._inputs = inputs
self._outputs = outputs
@@ -360,7 +430,7 @@ class ConvKFCBasicFB(KroneckerProductFB):
self._filter_shape = tuple(fltr.shape.as_list())
input_shape = tuple(inputs.shape.as_list())
- self._num_locations = (input_shape[1] * input_shape[2] /
+ self._num_locations = (input_shape[1] * input_shape[2] //
(strides[1] * strides[2]))
super(ConvKFCBasicFB, self).__init__(layer_collection)
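The change from / to // above matters for both ConvKFCBasicFB and the new
ConvDiagonalFB: assuming the standard "from __future__ import division"
header in these files, / is true division and would make _num_locations a
float. A minimal sketch of the difference:

    input_height, input_width = 28, 28
    stride_h, stride_w = 2, 2
    print(input_height * input_width / (stride_h * stride_w))   # 196.0 (float)
    print(input_height * input_width // (stride_h * stride_w))  # 196 (int)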
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
index 4937dd07db..c6cc169b37 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -31,6 +31,7 @@ _allowed_symbols = [
'KroneckerProductFB',
'FullyConnectedKFACBasicFB',
'ConvKFCBasicFB',
+ 'ConvDiagonalFB',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index a776ec0afa..3d14cf1ead 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
@@ -88,18 +89,19 @@ def _compute_cov(tensor, normalizer=None):
def _append_homog(tensor):
- """Appends a homogeneous coordinate to the row vectors of a 2D Tensor.
+ """Appends a homogeneous coordinate to the last dimension of a Tensor.
Args:
- tensor: A 2D Tensor.
+ tensor: A Tensor.
Returns:
A Tensor identical to the input but one larger in the last dimension. The
new entries are filled with ones.
"""
- size = array_ops.shape(tensor)[0]
- ones = array_ops.ones((size, 1), dtype=tensor.dtype)
- return array_ops.concat(values=[tensor, ones], axis=1)
+ rank = len(tensor.shape.as_list())
+ shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
+ ones = array_ops.ones(shape, dtype=tensor.dtype)
+ return array_ops.concat([tensor, ones], axis=rank-1)
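A NumPy analogue, for illustration only, of the generalized _append_homog:
the homogeneous coordinate is now appended along the last axis of a tensor of
any rank, which ConvDiagonalFactor below relies on for its rank-4 patch
tensors.

    import numpy as np

    def append_homog(tensor):
        # Append a trailing slice of ones along the last axis, for any rank.
        ones = np.ones(tensor.shape[:-1] + (1,), dtype=tensor.dtype)
        return np.concatenate([tensor, ones], axis=-1)

    x = np.zeros((2, 3, 4))        # e.g. [batch, location, feature]
    print(append_homog(x).shape)   # (2, 3, 5)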
def scope_string_from_params(params):
@@ -162,7 +164,7 @@ class FisherFactor(object):
representations.
Subclasses must implement the _compute_new_cov method, and the _var_scope
- and_cov_shape properties.
+ and _cov_shape properties.
"""
def __init__(self):
@@ -174,10 +176,19 @@ class FisherFactor(object):
@abc.abstractproperty
def _cov_shape(self):
+ """The shape of the cov matrix."""
pass
@abc.abstractproperty
def _num_sources(self):
+ """The number of things to sum over when computing cov.
+
+ The default make_covariance_update_op function will call _compute_new_cov
+ with indices ranging from 0 to _num_sources-1. The typical situation is
+ where the factor wants to sum the statistics it computes over multiple
+ backpropped "gradients" (typically passed in via "tensors" or
+ "outputs_grads" arguments).
+ """
pass
@property
@@ -409,6 +420,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
self._orig_tensors_name = scope_string_from_params((inputs,) +
tuple(outputs_grads))
+ # Note that we precompute the required operations on the inputs since the
+ # inputs don't change with the 'idx' argument to _compute_new_cov. Only
+ # the target entry of _outputs_grads changes with idx.
if has_bias:
inputs = _append_homog(inputs)
self._squared_inputs = math_ops.square(inputs)
@@ -428,7 +442,10 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
return len(self._outputs_grads)
def _compute_new_cov(self, idx=0):
- # the magic formula:
+ # The well-known special formula that uses the fact that the entry-wise
+ # square of an outer product is the outer product of the entry-wise squares.
+ # The gradient is the outer product of the input and the output gradients,
+ # so we just square both and then take their outer product.
new_cov = math_ops.matmul(
self._squared_inputs,
math_ops.square(self._outputs_grads[idx]),
@@ -437,6 +454,86 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
return new_cov
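A tiny NumPy check of the identity invoked above (values invented):

    import numpy as np

    u = np.array([1.0, -2.0, 3.0])   # inputs for one training case
    v = np.array([0.5, 4.0])         # output gradients for that case
    lhs = np.square(np.outer(u, v))             # square of the outer product
    rhs = np.outer(np.square(u), np.square(v))  # outer product of the squares
    assert np.allclose(lhs, rhs)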
+class ConvDiagonalFactor(DiagonalFactor):
+ """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
+
+ # TODO(jamesmartens): add unit tests for this class
+
+ def __init__(self, inputs, outputs_grads, filter_shape, strides, padding,
+ has_bias=False):
+ """Creates a ConvDiagonalFactor object.
+
+ Args:
+ inputs: Tensor of shape [batch_size, height, width, in_channels].
+ Input activations to this layer.
+ outputs_grads: Tensor of shape [batch_size, height, width, out_channels].
+ Per-example gradients to the loss with respect to the layer's output
+ preactivations.
+ filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
+ out_channels). Represents shape of kernel used in this layer.
+ strides: The stride size in this layer (1-D Tensor of length 4).
+ padding: The padding in this layer (1-D Tensor of length 4).
+ has_bias: Python bool. If True, the layer is assumed to have a bias
+ parameter in addition to its filter parameter.
+ """
+ self._filter_shape = filter_shape
+ self._has_bias = has_bias
+ self._outputs_grads = outputs_grads
+
+ self._orig_tensors_name = scope_string_from_name((inputs,)
+ + tuple(outputs_grads))
+
+ # Note that we precompute the required operations on the inputs since the
+ # inputs don't change with the 'idx' argument to _compute_new_cov. Only
+ # the target entry of _outputs_grads changes with idx.
+ filter_height, filter_width, _, _ = self._filter_shape
+ patches = array_ops.extract_image_patches(
+ inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=strides,
+ rates=[1, 1, 1, 1],
+ padding=padding)
+
+ if has_bias:
+ patches = _append_homog(patches)
+
+ self._patches = patches
+
+ super(ConvDiagonalFactor, self).__init__()
+
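For reference, a standalone sketch (TF 1.x API, shapes invented) of what the
extract_image_patches call above produces: each output location gets its
receptive field flattened into the last dimension, of size
kernel_height * kernel_width * in_channels.

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.randn(2, 5, 5, 3), dtype=tf.float32)
    patches = tf.extract_image_patches(x,
                                       ksizes=[1, 3, 3, 1],
                                       strides=[1, 1, 1, 1],
                                       rates=[1, 1, 1, 1],
                                       padding="SAME")
    print(patches.shape)  # (2, 5, 5, 27), since 3 * 3 * 3 = 27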
+ @property
+ def _var_scope(self):
+ return "ff_convdiag/" + self._orig_tensors_name
+
+ @property
+ def _cov_shape(self):
+ filter_height, filter_width, in_channels, out_channels = self._filter_shape
+ return [filter_height * filter_width * in_channels + self._has_bias,
+ out_channels]
+
+ @property
+ def _num_sources(self):
+ return len(self._outputs_grads)
+
+ def _compute_new_cov(self, idx=0):
+ outputs_grad = self._outputs_grads[idx]
+ batch_size = array_ops.shape(self._patches)[0]
+
+ new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad)
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+ return new_cov
+
+ def _convdiag_sum_of_squares(self, patches, outputs_grad):
+ # This computes the sum of the squares of the per-training-case "gradients".
+ # It does this simply by computing a giant tensor containing all of them,
+ # doing an entry-wise square, and then summing along the batch dimension.
+ case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
+ outputs_grad)
+ return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
+
+
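A NumPy cross-check of the einsum above against a naive per-case loop (shapes
invented): each case's gradient is a sum over spatial locations of outer
products between the flattened patch and the output gradient, and
_convdiag_sum_of_squares squares those per-case gradients and sums over the
batch.

    import numpy as np

    b, h, w, k, l = 2, 3, 3, 5, 4
    patches = np.random.randn(b, h, w, k)
    out_grads = np.random.randn(b, h, w, l)

    case_wise = np.einsum("bijk,bijl->bkl", patches, out_grads)
    fast = np.square(case_wise).sum(axis=0)

    slow = np.zeros((k, l))
    for n in range(b):
        g = np.zeros((k, l))
        for i in range(h):
            for j in range(w):
                g += np.outer(patches[n, i, j], out_grads[n, i, j])
        slow += np.square(g)

    assert np.allclose(fast, slow)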
class FullyConnectedKroneckerFactor(InverseProvidingFactor):
"""Kronecker factor for the input or output side of a fully-connected layer.
"""
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
index 8d9ba54e6e..49a07b1598 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
@@ -39,6 +39,7 @@ _allowed_symbols = [
"FullyConnectedKroneckerFactor",
"ConvInputKroneckerFactor",
"ConvOutputKroneckerFactor",
+ "ConvDiagonalFactor",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index e5de2ca17c..1b77f5d3ba 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -247,10 +247,17 @@ class LayerCollection(object):
else:
raise ValueError("Bad value {} for approx.".format(approx))
- def register_conv2d(self, params, strides, padding, inputs, outputs):
- self.register_block(params,
- fb.ConvKFCBasicFB(self, params, inputs, outputs,
- strides, padding))
+ def register_conv2d(self, params, strides, padding, inputs, outputs,
+ approx=APPROX_KRONECKER_NAME):
+
+ if approx == APPROX_KRONECKER_NAME:
+ self.register_block(params,
+ fb.ConvKFCBasicFB(self, params, inputs, outputs,
+ strides, padding))
+ elif approx == APPROX_DIAGONAL_NAME:
+ self.register_block(params,
+ fb.ConvDiagonalFB(self, params, inputs, outputs,
+ strides, padding))
+ else:
+ raise ValueError("Bad value {} for approx.".format(approx))
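A hypothetical usage sketch of the new argument; kernel, x, and pre_act are
placeholder tensors, and APPROX_DIAGONAL_NAME is the module-level constant
referenced above:

    lc = LayerCollection()
    lc.register_conv2d(params=kernel,        # [kh, kw, in_chan, out_chan]
                       strides=[1, 1, 1, 1],
                       padding="SAME",
                       inputs=x,             # [batch, h, w, in_chan]
                       outputs=pre_act,      # [batch, h, w, out_chan]
                       approx=APPROX_DIAGONAL_NAME)  # uses fb.ConvDiagonalFB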
def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
params = params if isinstance(params, (tuple, list)) else (params,)