author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-16 07:54:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-16 07:58:33 -0700
commit f4c6bd6b422c6383ac814c50aa2243442e1049cc (patch)
tree   ee59159c1c23f16d218910e2cc49b97d8339808e /tensorflow/contrib/kfac
parent 9e62c648a84f664fe338e1dec2db0f5e89ec3147 (diff)
- Adds support for shared embedding layers (e.g. in RNNs), and shared Conv2D layers.
- Some minor refactoring of internal structure in fisher_blocks and layer_collection

PiperOrigin-RevId: 189338874
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py | 18
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 330
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py | 260
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py | 8
4 files changed, 483 insertions, 133 deletions
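
As a quick orientation before the diff itself, here is a minimal usage sketch (not part of this commit) of the two new registration methods this change adds to LayerCollection for parameters shared across multiple uses/time-steps. Shapes, variable names, and the use of tf.get_variable are illustrative assumptions; the import path follows the files touched by this commit.

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import layer_collection

    lc = layer_collection.LayerCollection()

    # Embedding matrix shared across two uses/time-steps of the graph.
    embedding = tf.get_variable("embedding", shape=[10, 3])
    lc.register_embedding_multi(
        params=embedding,
        inputs=(tf.constant([[0]]), tf.constant([[1]])),  # int32 indices, one entry per use
        outputs=(tf.ones([1, 3]), tf.ones([1, 3])))       # one output Tensor per use

    # Conv2D kernel shared across two uses of the graph.
    kernel = tf.get_variable("kernel", shape=[3, 3, 4, 5])
    lc.register_conv2d_multi(
        params=kernel,
        strides=[1, 1, 1, 1],
        padding="SAME",
        inputs=(tf.ones([2, 8, 8, 4]), tf.ones([2, 8, 8, 4])),
        outputs=(tf.ones([2, 8, 8, 5]), tf.ones([2, 8, 8, 5])))

Both calls resolve to the new EmbeddingKFACMultiIndepFB and ConvKFCBasicMultiIndepFB blocks via the "kron_indep" approximation, which this commit also makes the default for these layer types (see layer_collection.py below).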
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index bae6bd7a3b..ba22099340 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -135,8 +135,22 @@ class LayerCollectionTest(test.TestCase):
array_ops.constant(6),
16,
approx=layer_collection.APPROX_DIAGONAL_NAME)
-
- self.assertEqual(9, len(lc.get_blocks()))
+ lc.register_fully_connected_multi(
+ array_ops.constant(1),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+ lc.register_conv2d_multi(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
+ outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
+ lc.register_embedding_multi(
+ array_ops.constant((1,)),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+
+ self.assertEqual(12, len(lc.get_blocks()))
def testRegisterBlocksMultipleRegistrations(self):
with ops.Graph().as_default():
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 31f4689fbf..79d0424dca 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -48,6 +48,7 @@ from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
# For blocks corresponding to convolutional layers, or any type of block where
# the parameters can be thought of as being replicated in time or space,
@@ -74,6 +75,86 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
PI_TYPE = pi_type
+def _make_partitionedtensors_inputs(inputs):
+ """Constructs PartitionedTensor for inputs.
+
+  The purpose of this function is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ Args:
+ inputs: a 1-D list of Tensors. Index is tower/mini-batch.
+
+ Returns:
+ A PartitionedTensor.
+ """
+ return utils.PartitionedTensor(inputs)
+
+
+def _make_partitionedtensors_grads(grads_list):
+ """Constructs PartitionedTensor for grads_list.
+
+  The purpose of this function is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ Args:
+ grads_list: 2-D list of Tensors. First index is for source, second
+ index for tower.
+
+ Returns:
+ Tuple of PartitionedTensors, one per source.
+ """
+ return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
+
+
+def _make_partitionedtensors_multi_inputs(inputs):
+ """Constructs PartitionedTensors for inputs.
+
+  The purpose of this function is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ This version of this function is for use with FisherBlocks that deal with
+ multiple uses or time-steps. One PartitionedTensor is created for each
+  use/time-step. The FisherBlock is responsible for concatenating
+  (or otherwise combining) the resulting PartitionedTensors.
+
+ Args:
+ inputs: a 2-D list of Tensors. First index is tower/mini-batch, second is
+ use/time-step.
+
+ Returns:
+ A tuple of PartitionedTensor's, one per use/time-step.
+ """
+ num_uses = len(inputs[0])
+ assert all(len(input_) == num_uses for input_ in inputs)
+
+ return tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
+
+
+def _make_partitionedtensors_multi_grads(grads_list):
+ """Constructs PartitionedTensors for grads_list.
+
+  The purpose of this function is to package up the towers/minibatch dimension
+ of these arrays into PartitionedTensor objects.
+
+ This version of this function is for use with FisherBlocks that deal with
+ multiple uses or time-steps. One PartitionedTensor is created for each
+  use/time-step. The FisherBlock is responsible for concatenating
+  (or otherwise combining) the resulting PartitionedTensors.
+
+ Args:
+ grads_list: 3-D list of Tensors. First index is for source, second is for
+ tower, third is for use/time-step.
+
+ Returns:
+ 2-D tuple of PartitionedTensors. First index is for source, second is for
+ use/time-step.
+ """
+ num_uses = len(grads_list[0][0])
+ assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+ return tuple(tuple(utils.PartitionedTensor(grad)
+ for grad in zip(*grads)) for grads in grads_list)
+
+
def normalize_damping(damping, num_replications):
"""Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
if NORMALIZE_DAMPING_POWER:
@@ -396,57 +477,6 @@ class InputOutputMultiMinibatch(object):
def _outputs(self):
return self.__outputs
- def _package_minibatches(self, grads_list):
- """Constructs PartitionedTensor for inputs, grads_list.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- Args:
- grads_list: 2-D list of Tensors. First index is for source, second
- index for tower.
-
- Returns:
- inputs: PartitionedTensor.
- grads_list: Tuple of PartitionedTensors, one per source.
- """
- inputs = utils.PartitionedTensor(self._inputs)
- grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list)
-
- return inputs, grads_list
-
- def _package_minibatches_multi(self, grads_list):
- """Constructs PartitionedTensors for inputs, grads_list.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- This version of this function is for use with FisherBlocks that deal with
- multiple uses or time-steps. One PartitionedTensor is created for each
- use/time-step.
-
- Args:
- grads_list: 3-D tuple of Tensors. First index is for source, second
- index is for tower, third is for use/time-step.
-
- Returns:
- inputs: A tuple of PartitionedTensor's, one per use/time-step.
- grads_list: 2-D tuple of PartitionedTensors. First index is for source,
- second is for use/time-step.
- """
- # self._inputs is a 2-D tuple. First index is tower/mini-batch, second is
- # use/time-step.
- inputs = self._inputs
- num_uses = len(inputs[0])
- assert all(len(input_) == num_uses for input_ in inputs)
- assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
-
- inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
- grads_list = tuple(tuple(utils.PartitionedTensor(grad)
- for grad in zip(*grads)) for grads in grads_list)
-
- return inputs, grads_list
-
class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
@@ -485,7 +515,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._package_minibatches(grads_list)
+ inputs = _make_partitionedtensors_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedDiagonalFactor,
@@ -598,7 +629,8 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
super(ConvDiagonalFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._package_minibatches(grads_list)
+ inputs = _make_partitionedtensors_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
# Infer number of locations upon which convolution is applied.
self._num_locations = num_conv_locations(inputs.shape.as_list(),
@@ -711,7 +743,7 @@ class KroneckerProductFB(FisherBlock):
class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for embedding layers.
- This FisherBlock is similar to EmbeddingKFACFB, except that its
+ This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
input factor is approximated by a diagonal matrix. In the case that each
example references exactly one embedding, this approximation is exact.
@@ -740,17 +772,78 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- inputs, grads_list = self._package_minibatches(grads_list)
+ inputs = _make_partitionedtensors_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
- self._input_factor = self._layer_collection.make_or_get_factor( #
- fisher_factors.EmbeddingInputKroneckerFactor, #
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
(inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor( #
- fisher_factors.FullyConnectedKroneckerFactor, #
- (grads_list,))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
self._setup_damping(damping)
+class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+ """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+ Similar to EmbeddingKFACFB except that this version supports multiple uses
+ of the parameter within a single model. These uses could correspond to
+ "time-steps", but they don't have to.
+
+ Does not support bias parameters.
+ """
+
+ def __init__(self, layer_collection, vocab_size):
+ """Creates a EmbeddingKFACMultiIndepFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ vocab_size: int. Size of vocabulary for this embedding layer.
+ """
+ self._vocab_size = vocab_size
+
+ super(EmbeddingKFACMultiIndepFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+ gradient of the loss with respect to 'outputs' from source 'i',
+ tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
+ [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs = self._inputs
+ self._num_uses = num_uses = len(inputs[0])
+
+ # Check that all mini-batches/towers have the same number of uses
+ assert all(len(input_) == num_uses for input_ in inputs)
+ # Do the same for grads_list
+    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+ # Merge uses and towers/minibatches dimensions together so we can handle
+ # it using a non-multi factor.
+ inputs = nest.flatten(inputs)
+
+ # Note that we call the multi version of make_partitionedtensors only for
+ # grads_list here.
+ inputs = _make_partitionedtensors_inputs(inputs)
+ grads_list = _make_partitionedtensors_multi_grads(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
+ (inputs, self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list,))
+ self._setup_damping(damping, normalization=num_uses)
+
+ @property
+ def _renorm_coeff(self):
+ return self._num_uses
+
+
class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
@@ -781,13 +874,14 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
damping: 0-D Tensor or float. 'damping' * identity is approximately added
to this FisherBlock's Fisher approximation.
"""
- inputs, grads_list = self._package_minibatches(grads_list)
+ inputs = _make_partitionedtensors_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
- self._input_factor = self._layer_collection.make_or_get_factor( #
- fisher_factors.FullyConnectedKroneckerFactor, #
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
((inputs,), self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor( #
- fisher_factors.FullyConnectedKroneckerFactor, #
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
(grads_list,))
self._setup_damping(damping)
@@ -858,12 +952,13 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._package_minibatches(grads_list)
-
# Infer number of locations upon which convolution is applied.
self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
self._strides)
+ inputs = _make_partitionedtensors_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
+
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._padding, self._strides,
@@ -1139,6 +1234,10 @@ def num_conv_locations(input_shape, strides):
class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters.
+
+ This class implements the "independence across time" approximation from the
+ following paper:
+ https://openreview.net/pdf?id=HyMTkQZAb
"""
def __init__(self, layer_collection, has_bias=False):
@@ -1156,7 +1255,8 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
def instantiate_factors(self, grads_list, damping):
self._num_uses = float(len(self._inputs[0]))
- inputs, grads_list = self._package_minibatches_multi(grads_list)
+ inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_multi_grads(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF,
@@ -1175,6 +1275,92 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
return self._outputs
+class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+ """FisherBlock for 2D convolutional layers using the basic KFC approx.
+
+ Similar to ConvKFCBasicFB except that this version supports multiple
+ uses/time-steps via a standard independence approximation. Similar to the
+ "independence across time" used in FullyConnectedMultiIndepFB but generalized
+ in the obvious way to conv layers.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None):
+ """Creates a ConvKFCBasicMultiIndepFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+      containing the kernel Tensor and a bias Tensor of shape [out_channels].
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
+ """
+ self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ super(ConvKFCBasicMultiIndepFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_locations = num_conv_locations(
+ self._inputs[0][0].shape.as_list(), self._strides)
+
+ # The first index is tower/minibatch, the second is use/time-step
+ inputs = self._inputs
+ self._num_uses = num_uses = len(inputs[0])
+
+ # Check that all mini-batches/towers have the same number of uses
+ assert all(len(input_) == num_uses for input_ in inputs)
+ assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+
+ # Fold uses/time-step and towers/minibatches dimensions together
+ inputs = nest.flatten(inputs)
+ # And do the same for grads_list
+ grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+
+ inputs = _make_partitionedtensors_inputs(inputs)
+ grads_list = _make_partitionedtensors_grads(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvInputKroneckerFactor,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
+ self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+ self._setup_damping(damping, normalization=(num_locations * num_uses))
+
+ @property
+ def _renorm_coeff(self):
+ return self._num_locations * self._num_uses
+
+
class SeriesFBApproximation(enum.IntEnum):
"""See FullyConnectedSeriesFB.__init__ for description and usage."""
option1 = 1
@@ -1184,7 +1370,8 @@ class SeriesFBApproximation(enum.IntEnum):
class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
"""FisherBlock for fully-connected layers that share parameters across time.
- See the following preprint for details:
+  This class implements the "Option 1" and "Option 2" approximations from the
+ following paper:
https://openreview.net/pdf?id=HyMTkQZAb
See the end of the appendix of the paper for a pseudo-code of the
@@ -1218,7 +1405,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
def instantiate_factors(self, grads_list, damping):
self._num_timesteps = len(self._inputs[0])
- inputs, grads_list = self._package_minibatches_multi(grads_list)
+ assert len(grads_list[0][0]) == self._num_timesteps
+
+ inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+ grads_list = _make_partitionedtensors_multi_grads(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
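
For readers skimming the new helpers above: the multi-use packaging is essentially a transpose of the [tower][use] nesting, followed by wrapping each per-use group in a PartitionedTensor. A tiny standalone sketch (plain Python stand-ins instead of Tensors, not part of the patch) of what _make_partitionedtensors_multi_inputs does:

    # inputs is indexed as inputs[tower][use].
    inputs = [["t0_use0", "t0_use1"],   # tower 0
              ["t1_use0", "t1_use1"]]   # tower 1

    # zip(*inputs) transposes the nesting so each group holds all towers for
    # one use/time-step; the real code wraps each group in
    # utils.PartitionedTensor instead of leaving it as a tuple.
    per_use = tuple(zip(*inputs))
    print(per_use)  # (('t0_use0', 't1_use0'), ('t0_use1', 't1_use1'))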
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 4eb5e4c092..00eae8b399 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -60,6 +60,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = {
APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
}
+_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
+}
+
APPROX_KRONECKER_INDEP_NAME = "kron_indep"
APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
@@ -72,6 +76,14 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
option=2)
}
+_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
+}
+
+_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
+}
+
# Possible value for 'reuse' keyword argument. Sets 'reuse' to
# tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
@@ -169,9 +181,12 @@ class LayerCollection(object):
self._default_generic_approximation = APPROX_FULL_NAME
self._default_embedding_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
- self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
+ self._default_conv2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
- APPROX_KRONECKER_SERIES_2_NAME)
+ APPROX_KRONECKER_INDEP_NAME)
+ self._default_conv2d_multi_approximation = (
+ APPROX_KRONECKER_INDEP_NAME)
+ self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
self.loss_colocation_ops = {}
self._vars_to_uses = defaultdict(lambda: 0)
@@ -245,14 +260,14 @@ class LayerCollection(object):
@property
def default_conv2d_approximation(self):
- return self._default_convolution_2d_approximation
+ return self._default_conv2d_approximation
def set_default_conv2d_approximation(self, value):
if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
raise ValueError(
"{} is not a valid approximation for 2d convolutional layers.".format(
value))
- self._default_convolution_2d_approximation = value
+ self._default_conv2d_approximation = value
@property
def default_fully_connected_multi_approximation(self):
@@ -264,6 +279,14 @@ class LayerCollection(object):
"multi layer.".format(value))
self._default_fully_connected_multi_approximation = value
+ @property
+ def default_conv2d_multi_approximation(self):
+ return self._default_conv2d_multi_approximation
+
+ @property
+ def default_embedding_multi_approximation(self):
+ return self._default_embedding_multi_approximation
+
def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
"""Validates and registers the layer_key associated with the fisher_block.
@@ -526,13 +549,24 @@ class LayerCollection(object):
else:
return None
+ def _get_block_type(self, params, approx, default, approx_to_type):
+ if approx is None:
+ approx = self._get_linked_approx(params)
+ if approx is None:
+ approx = default
+
+ if approx not in approx_to_type:
+ raise ValueError("Bad value {} for approx.".format(approx))
+
+ return approx_to_type[approx], approx
+
def register_embedding(self,
params,
inputs,
outputs,
approx=None,
reuse=VARIABLE_SCOPE):
- """Registers a fully connnected layer.
+ """Registers an embedding layer.
Args:
params: Embedding matrix of shape [vocab_size, embedding_size].
@@ -540,7 +574,8 @@ class LayerCollection(object):
into embedding matrix.
outputs: Tensor of shape [batch_size, output_size]. Outputs
produced by layer.
- approx: str. Must be "kron".
+ approx: str or None. If not None must be "kron". The Fisher
+ approximation to use. If None the default value is used. (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -550,20 +585,15 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = self.default_embedding_approximation
-
- if approx != APPROX_KRONECKER_NAME:
- raise ValueError("Bad value {} for approx.".format(approx))
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_embedding_approximation,
+ _EMBEDDING_APPROX_TO_BLOCK_TYPES)
if isinstance(params, (tuple, list)):
raise ValueError("Bias not supported.")
-
vocab_size = int(params.shape[0])
block = self.register_block(
- params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse)
+ params, block_type(self, vocab_size), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, 1)
@@ -583,7 +613,9 @@ class LayerCollection(object):
inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
outputs: Tensor of shape [batch_size, output_size]. Outputs
produced by layer.
- approx: str. One of "kron" or "diagonal".
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -593,17 +625,12 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = self.default_fully_connected_approximation
- if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
- raise ValueError("Bad value {} for approx.".format(approx))
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_fully_connected_approximation,
+ _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
- block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx]
has_bias = isinstance(params, (tuple, list))
-
block = self.register_block(params, block_type(self, has_bias=has_bias),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
@@ -635,7 +662,9 @@ class LayerCollection(object):
Output produced by layer.
data_format: str or None. Format of data.
dilations: List of 4 ints. Dilations along each dimension.
- approx: str. One of "kron" or "diagonal".
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -646,15 +675,14 @@ class LayerCollection(object):
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = self.default_conv2d_approximation
-
- if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES:
- raise ValueError("Bad value {} for approx.".format(approx))
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_conv2d_approximation,
+ _CONV2D_APPROX_TO_BLOCK_TYPES)
- block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
+ # It feels bad to pass in configuration that has to do with the internal
+ # implementation. And then we can't use the same constructor for both
+ # anymore and are thus forced to use this ugly if-statement.
+ # TODO(b/74793309): Clean this up?
if approx == APPROX_KRONECKER_NAME:
block = self.register_block(
params,
@@ -680,7 +708,7 @@ class LayerCollection(object):
data_format=data_format),
reuse=reuse)
else:
- raise NotImplementedError
+ raise NotImplementedError(approx)
block.register_additional_minibatch(inputs, outputs)
@@ -712,7 +740,9 @@ class LayerCollection(object):
dilation_rate: List of ints of length len(..input_spatial_size..).
Dilations along spatial dimension.
data_format: str or None. Format of data.
- approx: str. One of "kron" or "diagonal".
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -722,6 +752,8 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
+ # TODO(b/74793309): Have this use _get_block_type like the other
+ # registration functions?
assert approx is None or approx == APPROX_KRONECKER_NAME
block = self.register_block(
@@ -762,7 +794,8 @@ class LayerCollection(object):
rate: None or List of ints of length 2. Dilation rates in spatial
dimensions.
data_format: str or None. Format of data.
- approx: None or str. Must be "diagonal" if non-None.
+      approx: str or None. If not None must be "diagonal". The Fisher
+ approximation to use. If None the default value is used. (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -772,6 +805,8 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
+ # TODO(b/74793309): Have this use _get_block_type like the other
+ # registration functions?
assert approx is None or approx == APPROX_DIAGONAL_NAME
assert data_format in [None, "NHWC"]
@@ -803,7 +838,7 @@ class LayerCollection(object):
reuse=VARIABLE_SCOPE):
"""Register a call to tf.nn.separable_conv2d().
- Note: This requires access to intermediate outputs betwee depthwise and
+ Note: This requires access to intermediate outputs between depthwise and
pointwise convolutions.
Args:
@@ -824,7 +859,9 @@ class LayerCollection(object):
rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
kernel in spatial dimensions.
data_format: str or None. Format of data.
- approx: None or str. Must be "kron" if non-None.
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -865,7 +902,9 @@ class LayerCollection(object):
Args:
params: Tensor or tuple of Tensors corresponding to the parameters.
batch_size: 0-D Tensor. Size of the minibatch.
- approx: str. One of "full" or "diagonal".
+      approx: str or None. If not None, must be one of "full" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -875,16 +914,10 @@ class LayerCollection(object):
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_generic_approximation,
+ _GENERIC_APPROX_TO_BLOCK_TYPES)
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = self.default_generic_approximation
-
- if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES:
- raise ValueError("Bad value {} for approx.".format(approx))
-
- block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx]
block = self.register_block(params, block_type(self, params), reuse=reuse)
block.register_additional_minibatch(batch_size)
@@ -903,11 +936,15 @@ class LayerCollection(object):
this layer. Weight matrix should have shape [input_size, output_size].
Bias should have shape [output_size].
inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs
- to layer. In the case of RNNs, one Tensor per time step.
+ to layer. The list indexes each use in the graph (which might
+ correspond to a "time-step" in an RNN).
outputs: A list of tensors, the same length as 'inputs', each of shape
- [batch_size, output_size]. Outputs produced by layer. In the case of
- RNNs, one Tensor per time step.
- approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
+ [batch_size, output_size]. Outputs produced by layer. The list indexes
+ each use in the graph (which might correspond to a "time-step" in an
+ RNN). Needs to correspond with the order used in 'inputs'.
+    approx: str or None. If not None, must be one of "kron_indep", "kron_series_1"
+ or "kron_series_2". The Fisher approximation to use. If None the default
+ value is used. (Default: None)
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse.
@@ -915,28 +952,129 @@ class LayerCollection(object):
Raises:
ValueError: For improper value to 'approx'.
"""
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = self.default_fully_connected_multi_approximation
- has_bias = isinstance(params, (tuple, list))
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_fully_connected_multi_approximation,
+ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
# TODO(b/70283649): something along the lines of find_canonical_output
# should be added back in here (and for the other block types, arguably).
- if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
- raise ValueError("Bad value {} for approx.".format(approx))
- block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
-
+ has_bias = isinstance(params, (tuple, list))
block = self.register_block(params, block_type(self, has_bias=has_bias),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
+
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+
+ def register_conv2d_multi(self,
+ params,
+ strides,
+ padding,
+ inputs,
+ outputs,
+ data_format=None,
+ dilations=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers convolutional layers with shared parameters.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [kernel_height,
+ kernel_width, in_channels, out_channels]. Bias should have shape
+ [out_channels].
+ strides: 1-D Tensor of length 4. Strides for convolution kernel.
+      padding: str. See tf.nn.conv2d for valid values.
+ inputs: A list of Tensors, each of shape [batch_size, height, width,
+ in_channels]. Inputs to layer. The list indexes each use in the graph
+ (which might correspond to a "time-step" in an RNN).
+ outputs: A list of Tensors, each of shape [batch_size, height, width,
+ out_channels]. Output produced by layer. The list indexes each use
+ in the graph (which might correspond to a "time-step" in an RNN).
+ Needs to correspond with the order used in 'inputs'.
+ data_format: str or None. Format of data.
+ dilations: List of 4 ints. Dilations along each dimension.
+      approx: str or None. If not None must be "kron_indep". The Fisher
+ approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_conv2d_multi_approximation,
+ _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
+
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ data_format=data_format,
+ dilation_rate=dilations,
+ extract_patches_fn="extract_image_patches"),
+ reuse=reuse)
+
+ block.register_additional_minibatch(inputs, outputs)
+
+ assert len(inputs) == len(outputs)
self._add_uses(params, len(inputs))
# TODO(b/74108452): change the loss registration functions names to refer
# to "loss functions" instead of distributions. Following naming convention
# of the loss function classes themselves.
+ def register_embedding_multi(self,
+ params,
+ inputs,
+ outputs,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers embedding layers with shared parameters.
+
+ Args:
+ params: Embedding matrix of shape [vocab_size, embedding_size].
+ inputs: A list of Tensors, each of shape [batch_size, input_size] and
+ dtype int32. Indices into embedding matrix. The list indexes each use
+ in the graph (which might correspond to a "time-step" in an RNN).
+ outputs: A list of Tensors, each of shape [batch_size, output_size].
+ Outputs produced by layer. The list indexes each use in the graph
+ (which might correspond to a "time-step" in an RNN). Needs to
+ correspond with the order used in 'inputs'.
+      approx: str or None. If not None must be "kron_indep". The Fisher
+ approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+ create a new FisherBlock. If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse.
+
+ Raises:
+ ValueError: For improper value to 'approx'.
+ KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_embedding_multi_approximation,
+ _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
+
+ if isinstance(params, (tuple, list)):
+ raise ValueError("Bias not supported.")
+ vocab_size = int(params.shape[0])
+
+ block = self.register_block(
+ params, block_type(self, vocab_size), reuse=reuse)
+ block.register_additional_minibatch(inputs, outputs)
+
+ self._add_uses(params, len(inputs))
+
def register_categorical_predictive_distribution(self,
logits,
seed=None,
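
The small _get_block_type helper introduced above centralizes how the register_* methods resolve their 'approx' argument: an explicit value wins, otherwise any approximation previously linked to the parameters, otherwise the per-layer-type default, and anything outside the allowed table raises ValueError. A self-contained sketch of that precedence (hypothetical function, not the actual helper):

    def resolve_approx(explicit, linked, default, valid_to_block_type):
        # Precedence: explicit argument > linked approximation > default.
        approx = explicit
        if approx is None:
            approx = linked if linked is not None else default
        if approx not in valid_to_block_type:
            raise ValueError("Bad value {} for approx.".format(approx))
        return valid_to_block_type[approx], approx

    table = {"kron_indep": "ConvKFCBasicMultiIndepFB"}
    print(resolve_approx(None, None, "kron_indep", table))
    # ('ConvKFCBasicMultiIndepFB', 'kron_indep')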
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index af26f5e56b..c589b18193 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -659,6 +659,14 @@ class PartitionedTensor(object):
def __hash__(self):
return hash(tuple(self.tensors))
+ def __eq__(self, other):
+ if not isinstance(other, PartitionedTensor):
+ return False
+ return self.tensors == other.tensors
+
+ def __ne__(self, other):
+ return not self == other # pylint: disable=g-comparison-negation
+
def as_tensor(self, dtype=None, name=None, as_ref=False):
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
assert not as_ref
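
The new __eq__/__ne__ methods on PartitionedTensor compare the wrapped tensor lists directly, so two PartitionedTensors are equal only when they hold the same Tensor objects in the same order (consistent with the existing hash over tuple(self.tensors)). A short sketch of the resulting behaviour, assuming a TF 1.x environment where the contrib kfac package is available:

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import utils

    a = tf.ones([2, 3])
    b = tf.ones([2, 3])

    p1 = utils.PartitionedTensor([a, b])
    p2 = utils.PartitionedTensor([a, b])
    p3 = utils.PartitionedTensor([b, a])

    print(p1 == p2)   # True: same Tensors, same order
    print(p1 == p3)   # False: same Tensors, different order
    print(p1 != 7)    # True: non-PartitionedTensor values never compare equal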