author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-16 07:54:08 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-16 07:58:33 -0700
commit | f4c6bd6b422c6383ac814c50aa2243442e1049cc (patch)
tree | ee59159c1c23f16d218910e2cc49b97d8339808e /tensorflow/contrib/kfac
parent | 9e62c648a84f664fe338e1dec2db0f5e89ec3147 (diff)
- Adds support for shared embedding layers (e.g. in RNNs), and shared Conv2D layers.
- Some minor refactoring of internal structure in fisher_blocks and layer_collection
PiperOrigin-RevId: 189338874
Diffstat (limited to 'tensorflow/contrib/kfac')
4 files changed, 483 insertions, 133 deletions
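
For orientation before the diff itself: the new multi-use registration API is exercised end-to-end by the updated layer_collection_test.py below. Here is a minimal sketch of the calls, mirroring the test; the constants and shapes are the test's placeholder values, not realistic layer shapes:

    # TF 1.x-era contrib API; values below are the test's placeholders.
    from tensorflow.contrib.kfac.python.ops import layer_collection
    from tensorflow.python.ops import array_ops

    lc = layer_collection.LayerCollection()

    # A fully-connected layer whose weights are used at two "time-steps":
    # 'inputs' and 'outputs' are tuples/lists indexed by use.
    lc.register_fully_connected_multi(
        array_ops.constant(1),
        (array_ops.constant(2), array_ops.constant(3)),
        (array_ops.constant(4), array_ops.constant(5)))

    # A shared Conv2D layer used twice in the graph.
    lc.register_conv2d_multi(
        params=array_ops.ones((2, 3, 4, 5)),
        strides=[1, 1, 1, 1],
        padding='SAME',
        inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
        outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))

    # A shared embedding layer (e.g. tied embeddings in an RNN).
    lc.register_embedding_multi(
        array_ops.constant((1,)),
        (array_ops.constant(2), array_ops.constant(3)),
        (array_ops.constant(4), array_ops.constant(5)))

Each call creates a single FisherBlock that covers every listed use of the shared parameter.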
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index bae6bd7a3b..ba22099340 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -135,8 +135,22 @@ class LayerCollectionTest(test.TestCase):
           array_ops.constant(6),
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)
-
-      self.assertEqual(9, len(lc.get_blocks()))
+      lc.register_fully_connected_multi(
+          array_ops.constant(1),
+          (array_ops.constant(2), array_ops.constant(3)),
+          (array_ops.constant(4), array_ops.constant(5)))
+      lc.register_conv2d_multi(
+          params=array_ops.ones((2, 3, 4, 5)),
+          strides=[1, 1, 1, 1],
+          padding='SAME',
+          inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
+          outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
+      lc.register_embedding_multi(
+          array_ops.constant((1,)),
+          (array_ops.constant(2), array_ops.constant(3)),
+          (array_ops.constant(4), array_ops.constant(5)))
+
+      self.assertEqual(12, len(lc.get_blocks()))
 
   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 31f4689fbf..79d0424dca 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -48,6 +48,7 @@ from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
 
 # For blocks corresponding to convolutional layers, or any type of block where
 # the parameters can be thought of as being replicated in time or space,
@@ -74,6 +75,86 @@ def set_global_constants(normalize_damping_power=None, pi_type=None):
     PI_TYPE = pi_type
 
 
+def _make_partitionedtensors_inputs(inputs):
+  """Constructs PartitionedTensor for inputs.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  Args:
+    inputs: a 1-D list of Tensors. Index is tower/mini-batch.
+
+  Returns:
+    A PartitionedTensor.
+  """
+  return utils.PartitionedTensor(inputs)
+
+
+def _make_partitionedtensors_grads(grads_list):
+  """Constructs PartitionedTensor for grads_list.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  Args:
+    grads_list: 2-D list of Tensors. First index is for source, second
+      index for tower.
+
+  Returns:
+    Tuple of PartitionedTensors, one per source.
+  """
+  return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
+
+
+def _make_partitionedtensors_multi_inputs(inputs):
+  """Constructs PartitionedTensors for inputs.
+
+  The purpose of this method is to package up the towers/minibatch dimension
+  of these arrays into PartitionedTensor objects.
+
+  This version of this function is for use with FisherBlocks that deal with
+  multiple uses or time-steps. One PartitionedTensor is created for each
+  use/time-step. The FisherBlock will be responsible for concatenating
+  (or doing whatever else it wants) with the resulting lists.
+
+  Args:
+    inputs: a 2-D list of Tensors. First index is tower/mini-batch, second is
+      use/time-step.
+
+  Returns:
+    A tuple of PartitionedTensors, one per use/time-step.
+ """ + num_uses = len(inputs[0]) + assert all(len(input_) == num_uses for input_ in inputs) + + return tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs)) + + +def _make_partitionedtensors_multi_grads(grads_list): + """Constructs PartitionedTensors for grads_list. + + The purpose of this method is to package up the towers/minibatch dimension + of these arrays into PartitionedTensor objects. + + This version of this function is for use with FisherBlocks that deal with + multiple uses or time-steps. One PartitionedTensor is created for each + use/time-step. The FisherBlock will be responsible for concatenating + (or doing whatever else it wants) with the resulting lists. + + Args: + grads_list: 3-D list of Tensors. First index is for source, second is for + tower, third is for use/time-step. + + Returns: + 2-D tuple of PartitionedTensors. First index is for source, second is for + use/time-step. + """ + num_uses = len(grads_list[0][0]) + assert all(len(grad) == num_uses for grads in grads_list for grad in grads) + return tuple(tuple(utils.PartitionedTensor(grad) + for grad in zip(*grads)) for grads in grads_list) + + def normalize_damping(damping, num_replications): """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" if NORMALIZE_DAMPING_POWER: @@ -396,57 +477,6 @@ class InputOutputMultiMinibatch(object): def _outputs(self): return self.__outputs - def _package_minibatches(self, grads_list): - """Constructs PartitionedTensor for inputs, grads_list. - - The purpose of this method is to package up the towers/minibatch dimension - of these arrays into PartitionedTensor objects. - - Args: - grads_list: 2-D list of Tensors. First index is for source, second - index for tower. - - Returns: - inputs: PartitionedTensor. - grads_list: Tuple of PartitionedTensors, one per source. - """ - inputs = utils.PartitionedTensor(self._inputs) - grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list) - - return inputs, grads_list - - def _package_minibatches_multi(self, grads_list): - """Constructs PartitionedTensors for inputs, grads_list. - - The purpose of this method is to package up the towers/minibatch dimension - of these arrays into PartitionedTensor objects. - - This version of this function is for use with FisherBlocks that deal with - multiple uses or time-steps. One PartitionedTensor is created for each - use/time-step. - - Args: - grads_list: 3-D tuple of Tensors. First index is for source, second - index is for tower, third is for use/time-step. - - Returns: - inputs: A tuple of PartitionedTensor's, one per use/time-step. - grads_list: 2-D tuple of PartitionedTensors. First index is for source, - second is for use/time-step. - """ - # self._inputs is a 2-D tuple. First index is tower/mini-batch, second is - # use/time-step. - inputs = self._inputs - num_uses = len(inputs[0]) - assert all(len(input_) == num_uses for input_ in inputs) - assert all(len(grad) == num_uses for grads in grads_list for grad in grads) - - inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs)) - grads_list = tuple(tuple(utils.PartitionedTensor(grad) - for grad in zip(*grads)) for grads in grads_list) - - return inputs, grads_list - class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): """FisherBlock for fully-connected (dense) layers using a diagonal approx. 
@@ -485,7 +515,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
     super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedDiagonalFactor,
@@ -598,7 +629,8 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
     super(ConvDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
     # Infer number of locations upon which convolution is applied.
     self._num_locations = num_conv_locations(inputs.shape.as_list(),
@@ -711,7 +743,7 @@ class KroneckerProductFB(FisherBlock):
 class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """K-FAC FisherBlock for embedding layers.
 
-  This FisherBlock is similar to EmbeddingKFACFB, except that its
+  This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
   input factor is approximated by a diagonal matrix. In the case that each
   example references exactly one embedding, this approximation is exact.
 
@@ -740,17 +772,78 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
       damping: 0-D Tensor or float. 'damping' * identity is approximately
         added to this FisherBlock's Fisher approximation.
     """
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
-    self._input_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.EmbeddingInputKroneckerFactor,  #
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.EmbeddingInputKroneckerFactor,
         (inputs, self._vocab_size))
-    self._output_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
-        (grads_list,))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
     self._setup_damping(damping)
 
 
+class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+  """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+  Similar to EmbeddingKFACFB except that this version supports multiple uses
+  of the parameter within a single model. These uses could correspond to
+  "time-steps", but they don't have to.
+
+  Does not support bias parameters.
+  """
+
+  def __init__(self, layer_collection, vocab_size):
+    """Creates an EmbeddingKFACMultiIndepFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+        Fisher information matrix to which this FisherBlock belongs.
+      vocab_size: int. Size of vocabulary for this embedding layer.
+    """
+    self._vocab_size = vocab_size
+
+    super(EmbeddingKFACMultiIndepFB, self).__init__(layer_collection)
+
+  def instantiate_factors(self, grads_list, damping):
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+        gradient of the loss with respect to 'outputs' from source 'i',
+        tower/mini-batch 'j', and use/time-step 'k'.
+        Each Tensor has shape [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    inputs = self._inputs
+    self._num_uses = num_uses = len(inputs[0])
+
+    # Check that all mini-batches/towers have the same number of uses.
+    assert all(len(input_) == num_uses for input_ in inputs)
+    # Do the same for grads_list.
+    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+
+    # Merge uses and towers/minibatches dimensions together so we can handle
+    # it using a non-multi factor.
+    inputs = nest.flatten(inputs)
+
+    # Note that we call the multi version of make_partitionedtensors only for
+    # grads_list here.
+    inputs = _make_partitionedtensors_inputs(inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.EmbeddingInputKroneckerFactor,
+        (inputs, self._vocab_size))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedMultiKF, (grads_list,))
+    self._setup_damping(damping, normalization=num_uses)
+
+  @property
+  def _renorm_coeff(self):
+    return self._num_uses
+
+
 class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
   """K-FAC FisherBlock for fully-connected (dense) layers.
 
@@ -781,13 +874,14 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
       damping: 0-D Tensor or float. 'damping' * identity is approximately added
         to this FisherBlock's Fisher approximation.
     """
-    inputs, grads_list = self._package_minibatches(grads_list)
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
 
-    self._input_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor,
         ((inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(  #
-        fisher_factors.FullyConnectedKroneckerFactor,  #
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.FullyConnectedKroneckerFactor,
         (grads_list,))
     self._setup_damping(damping)
 
@@ -858,12 +952,13 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
     super(ConvKFCBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._package_minibatches(grads_list)
-
     # Infer number of locations upon which convolution is applied.
     self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
                                              self._strides)
 
+    inputs = _make_partitionedtensors_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
+
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
         (inputs, self._filter_shape, self._padding, self._strides,
@@ -1139,6 +1234,10 @@ def num_conv_locations(input_shape, strides):
 class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch,
                                  KroneckerProductFB):
   """FisherBlock for fully-connected layers that share parameters.
+
+  This class implements the "independence across time" approximation from the
+  following paper:
+    https://openreview.net/pdf?id=HyMTkQZAb
   """
 
   def __init__(self, layer_collection, has_bias=False):
@@ -1156,7 +1255,8 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
 
   def instantiate_factors(self, grads_list, damping):
     self._num_uses = float(len(self._inputs[0]))
-    inputs, grads_list = self._package_minibatches_multi(grads_list)
+    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF,
@@ -1175,6 +1275,92 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
     return self._outputs
 
 
+class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+  """FisherBlock for 2D convolutional layers using the basic KFC approx.
+
+  Similar to ConvKFCBasicFB except that this version supports multiple
+  uses/time-steps via a standard independence approximation. Similar to the
+  "independence across time" used in FullyConnectedMultiIndepFB but generalized
+  in the obvious way to conv layers.
+  """
+
+  def __init__(self,
+               layer_collection,
+               params,
+               padding,
+               strides=None,
+               dilation_rate=None,
+               data_format=None,
+               extract_patches_fn=None):
+    """Creates a ConvKFCBasicMultiIndepFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+        Fisher information matrix to which this FisherBlock belongs.
+      params: The parameters (Tensor or tuple of Tensors) of this layer. If
+        kernel alone, a Tensor of shape [..spatial_filter_shape..,
+        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+        containing the previous and a Tensor of shape [out_channels].
+      padding: str. Padding method.
+      strides: List of ints or None. Contains [..spatial_filter_strides..] if
+        'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides.., 1].
+      dilation_rate: List of ints or None. Rate for dilation along each spatial
+        dimension if 'extract_patches_fn' is compatible with
+        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+      data_format: str or None. Format of input data.
+      extract_patches_fn: str or None. Name of function that extracts image
+        patches. One of "extract_convolution_patches", "extract_image_patches",
+        "extract_pointwise_conv2d_patches".
+    """
+    self._padding = padding
+    self._strides = maybe_tuple(strides)
+    self._dilation_rate = maybe_tuple(dilation_rate)
+    self._data_format = data_format
+    self._extract_patches_fn = extract_patches_fn
+    self._has_bias = isinstance(params, (tuple, list))
+
+    fltr = params[0] if self._has_bias else params
+    self._filter_shape = tuple(fltr.shape.as_list())
+
+    super(ConvKFCBasicMultiIndepFB, self).__init__(layer_collection)
+
+  def instantiate_factors(self, grads_list, damping):
+    # Infer number of locations upon which convolution is applied.
+    self._num_locations = num_locations = num_conv_locations(
+        self._inputs[0][0].shape.as_list(), self._strides)
+
+    # The first index is tower/minibatch, the second is use/time-step.
+    inputs = self._inputs
+    self._num_uses = num_uses = len(inputs[0])
+
+    # Check that all mini-batches/towers have the same number of uses.
+    assert all(len(input_) == num_uses for input_ in inputs)
+    assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
+
+    # Fold uses/time-step and towers/minibatches dimensions together.
+    inputs = nest.flatten(inputs)
+    # And do the same for grads_list.
+    grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+
+    inputs = _make_partitionedtensors_inputs(inputs)
+    grads_list = _make_partitionedtensors_grads(grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.ConvInputKroneckerFactor,
+        (inputs, self._filter_shape, self._padding, self._strides,
+         self._dilation_rate, self._data_format, self._extract_patches_fn,
+         self._has_bias))
+    self._output_factor = self._layer_collection.make_or_get_factor(
+        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+    self._setup_damping(damping, normalization=(num_locations * num_uses))
+
+  @property
+  def _renorm_coeff(self):
+    return self._num_locations * self._num_uses
+
+
 class SeriesFBApproximation(enum.IntEnum):
   """See FullyConnectedSeriesFB.__init__ for description and usage."""
   option1 = 1
@@ -1184,7 +1370,8 @@ class SeriesFBApproximation(enum.IntEnum):
 class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
   """FisherBlock for fully-connected layers that share parameters across time.
 
-  See the following preprint for details:
+  This class implements the "Option 1" and "Option 2" approximations from the
+  following paper:
     https://openreview.net/pdf?id=HyMTkQZAb
 
   See the end of the appendix of the paper for a pseudo-code of the
@@ -1218,7 +1405,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
 
   def instantiate_factors(self, grads_list, damping):
     self._num_timesteps = len(self._inputs[0])
-    inputs, grads_list = self._package_minibatches_multi(grads_list)
+    assert len(grads_list[0][0]) == self._num_timesteps
+
+    inputs = _make_partitionedtensors_multi_inputs(self._inputs)
+    grads_list = _make_partitionedtensors_multi_grads(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 4eb5e4c092..00eae8b399 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -60,6 +60,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = {
     APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
 }
 
+_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
+}
+
 APPROX_KRONECKER_INDEP_NAME = "kron_indep"
 APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
 APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
@@ -72,6 +76,14 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
         option=2)
 }
 
+_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
+}
+
+_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
+    APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
+}
+
 # Possible value for 'reuse' keyword argument. Sets 'reuse' to
 # tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE" @@ -169,9 +181,12 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_SERIES_2_NAME) + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME self.loss_colocation_ops = {} self._vars_to_uses = defaultdict(lambda: 0) @@ -245,14 +260,14 @@ class LayerCollection(object): @property def default_conv2d_approximation(self): - return self._default_convolution_2d_approximation + return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) - self._default_convolution_2d_approximation = value + self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): @@ -264,6 +279,14 @@ class LayerCollection(object): "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -526,13 +549,24 @@ class LayerCollection(object): else: return None + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + def register_embedding(self, params, inputs, outputs, approx=None, reuse=VARIABLE_SCOPE): - """Registers a fully connnected layer. + """Registers an embedding layer. Args: params: Embedding matrix of shape [vocab_size, embedding_size]. @@ -540,7 +574,8 @@ class LayerCollection(object): into embedding matrix. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. Must be "kron". + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) reuse: bool or str. If True, reuse an existing FisherBlock. If False, create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. @@ -550,20 +585,15 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. 
""" - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_embedding_approximation - - if approx != APPROX_KRONECKER_NAME: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) if isinstance(params, (tuple, list)): raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) block = self.register_block( - params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) + params, block_type(self, vocab_size), reuse=reuse) block.register_additional_minibatch(inputs, outputs) self._add_uses(params, 1) @@ -583,7 +613,9 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of "kron" or "diagonal". + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) reuse: bool or str. If True, reuse an existing FisherBlock. If False, create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. @@ -593,17 +625,12 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_approximation - if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) - block = self.register_block(params, block_type(self, has_bias=has_bias), reuse=reuse) block.register_additional_minibatch(inputs, outputs) @@ -635,7 +662,9 @@ class LayerCollection(object): Output produced by layer. data_format: str or None. Format of data. dilations: List of 4 ints. Dilations along each dimension. - approx: str. One of "kron" or "diagonal". + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) reuse: bool or str. If True, reuse an existing FisherBlock. If False, create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. @@ -646,15 +675,14 @@ class LayerCollection(object): ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_conv2d_approximation - - if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) - block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can't use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? 
     if approx == APPROX_KRONECKER_NAME:
       block = self.register_block(
           params,
@@ -680,7 +708,7 @@ class LayerCollection(object):
               data_format=data_format),
           reuse=reuse)
     else:
-      raise NotImplementedError
+      raise NotImplementedError(approx)
 
     block.register_additional_minibatch(inputs, outputs)
 
@@ -712,7 +740,9 @@ class LayerCollection(object):
       dilation_rate: List of ints of length len(..input_spatial_size..).
         Dilations along spatial dimension.
       data_format: str or None. Format of data.
-      approx: str. One of "kron" or "diagonal".
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str. If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock. If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -722,6 +752,8 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    # TODO(b/74793309): Have this use _get_block_type like the other
+    # registration functions?
     assert approx is None or approx == APPROX_KRONECKER_NAME
 
     block = self.register_block(
@@ -762,7 +794,8 @@ class LayerCollection(object):
       rate: None or List of ints of length 2. Dilation rates in spatial
         dimensions.
       data_format: str or None. Format of data.
-      approx: None or str. Must be "diagonal" if non-None.
+      approx: str or None. If not None must be "diagonal". The Fisher
+        approximation to use. If None the default value is used. (Default: None)
       reuse: bool or str. If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock. If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -772,6 +805,8 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    # TODO(b/74793309): Have this use _get_block_type like the other
+    # registration functions?
     assert approx is None or approx == APPROX_DIAGONAL_NAME
     assert data_format in [None, "NHWC"]
 
@@ -803,7 +838,7 @@ class LayerCollection(object):
                                 reuse=VARIABLE_SCOPE):
     """Register a call to tf.nn.separable_conv2d().
 
-    Note: This requires access to intermediate outputs betwee depthwise and
+    Note: This requires access to intermediate outputs between depthwise and
     pointwise convolutions.
 
     Args:
@@ -824,7 +859,9 @@ class LayerCollection(object):
       rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
         kernel in spatial dimensions.
       data_format: str or None. Format of data.
-      approx: None or str. Must be "kron" if non-None.
+      approx: str or None. If not None must be one of "kron" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
      reuse: bool or str. If True, reuse an existing FisherBlock. If False,
        create a new FisherBlock. If "VARIABLE_SCOPE", use
        tf.get_variable_scope().reuse.
@@ -865,7 +902,9 @@ class LayerCollection(object):
     Args:
       params: Tensor or tuple of Tensors corresponding to the parameters.
       batch_size: 0-D Tensor. Size of the minibatch.
-      approx: str. One of "full" or "diagonal".
+      approx: str or None. If not None, must be one of "full" or "diagonal".
+        The Fisher approximation to use. If None the default value is used.
+        (Default: None)
       reuse: bool or str. If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock. If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -875,16 +914,10 @@ class LayerCollection(object):
       KeyError: If reuse == True but no FisherBlock found for 'params'.
       ValueError: If reuse == True and FisherBlock found but of the wrong type.
     """
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_generic_approximation,
+        _GENERIC_APPROX_TO_BLOCK_TYPES)
 
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_generic_approximation
-
-    if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
-
-    block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx]
     block = self.register_block(params, block_type(self, params), reuse=reuse)
     block.register_additional_minibatch(batch_size)
 
@@ -903,11 +936,15 @@ class LayerCollection(object):
         this layer. Weight matrix should have shape [input_size, output_size].
         Bias should have shape [output_size].
       inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs
-        to layer. In the case of RNNs, one Tensor per time step.
+        to layer. The list indexes each use in the graph (which might
+        correspond to a "time-step" in an RNN).
       outputs: A list of tensors, the same length as 'inputs', each of shape
-        [batch_size, output_size]. Outputs produced by layer. In the case of
-        RNNs, one Tensor per time step.
-      approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2".
+        [batch_size, output_size]. Outputs produced by layer. The list indexes
+        each use in the graph (which might correspond to a "time-step" in an
+        RNN). Needs to correspond with the order used in 'inputs'.
+      approx: str or None. If not None, must be one of "kron_indep",
+        "kron_series_1" or "kron_series_2". The Fisher approximation to use.
+        If None the default value is used. (Default: None)
       reuse: bool or str. If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock. If "VARIABLE_SCOPE", use
         tf.get_variable_scope().reuse.
@@ -915,28 +952,129 @@ class LayerCollection(object):
     Raises:
       ValueError: For improper value to 'approx'.
     """
-    if approx is None:
-      approx = self._get_linked_approx(params)
-      if approx is None:
-        approx = self.default_fully_connected_multi_approximation
-    has_bias = isinstance(params, (tuple, list))
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_fully_connected_multi_approximation,
+        _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
 
     # TODO(b/70283649): something along the lines of find_canonical_output
     # should be added back in here (and for the other block types, arguably).
-    if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
-      raise ValueError("Bad value {} for approx.".format(approx))
-    block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx]
-
+    has_bias = isinstance(params, (tuple, list))
     block = self.register_block(params, block_type(self, has_bias=has_bias),
                                 reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
+
+    assert len(inputs) == len(outputs)
     self._add_uses(params, len(inputs))
 
+  def register_conv2d_multi(self,
+                            params,
+                            strides,
+                            padding,
+                            inputs,
+                            outputs,
+                            data_format=None,
+                            dilations=None,
+                            approx=None,
+                            reuse=VARIABLE_SCOPE):
+    """Registers convolutional layers with shared parameters.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+        this layer. Weight matrix should have shape [kernel_height,
+        kernel_width, in_channels, out_channels]. Bias should have shape
+        [out_channels].
+      strides: 1-D Tensor of length 4. Strides for convolution kernel.
+      padding: string. See tf.nn.conv2d for valid values.
+      inputs: A list of Tensors, each of shape [batch_size, height, width,
+        in_channels]. Inputs to layer. The list indexes each use in the graph
+        (which might correspond to a "time-step" in an RNN).
+      outputs: A list of Tensors, each of shape [batch_size, height, width,
+        out_channels]. Output produced by layer. The list indexes each use
+        in the graph (which might correspond to a "time-step" in an RNN).
+        Needs to correspond with the order used in 'inputs'.
+      data_format: str or None. Format of data.
+      dilations: List of 4 ints. Dilations along each dimension.
+      approx: str or None. If not None must be "kron_indep". The Fisher
+        approximation to use. If None the default value is used.
+        (Default: None)
+      reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock. If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    block_type, approx = self._get_block_type(
+        params, approx, self.default_conv2d_multi_approximation,
+        _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
+
+    block = self.register_block(
+        params,
+        block_type(
+            layer_collection=self,
+            params=params,
+            padding=padding,
+            strides=strides,
+            data_format=data_format,
+            dilation_rate=dilations,
+            extract_patches_fn="extract_image_patches"),
+        reuse=reuse)
+
+    block.register_additional_minibatch(inputs, outputs)
+
+    assert len(inputs) == len(outputs)
+    self._add_uses(params, len(inputs))
+
   # TODO(b/74108452): change the loss registration functions names to refer
   # to "loss functions" instead of distributions. Following naming convention
   # of the loss function classes themselves.
 
+  def register_embedding_multi(self,
+                               params,
+                               inputs,
+                               outputs,
+                               approx=None,
+                               reuse=VARIABLE_SCOPE):
+    """Registers embedding layers with shared parameters.
+
+    Args:
+      params: Embedding matrix of shape [vocab_size, embedding_size].
+      inputs: A list of Tensors, each of shape [batch_size, input_size] and
+        dtype int32. Indices into embedding matrix. The list indexes each use
+        in the graph (which might correspond to a "time-step" in an RNN).
+      outputs: A list of Tensors, each of shape [batch_size, output_size].
+        Outputs produced by layer. The list indexes each use in the graph
+        (which might correspond to a "time-step" in an RNN). Needs to
+        correspond with the order used in 'inputs'.
+      approx: str or None. If not None must be "kron_indep". The Fisher
+        approximation to use. If None the default value is used.
+        (Default: None)
+      reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock. If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + + self._add_uses(params, len(inputs)) + def register_categorical_predictive_distribution(self, logits, seed=None, diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index af26f5e56b..c589b18193 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -659,6 +659,14 @@ class PartitionedTensor(object): def __hash__(self): return hash(tuple(self.tensors)) + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + def as_tensor(self, dtype=None, name=None, as_ref=False): with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): assert not as_ref |