path: root/tensorflow/contrib/kfac
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-20 03:11:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-20 03:15:36 -0700
commit    163bb675579bbc3a115c0caac9b42891f629bfd4 (patch)
tree      8b84343975f146d931b39344f22b9806fb10b4ce /tensorflow/contrib/kfac
parent    28db3a7eae4986e3e662de16188cf7a03be33768 (diff)
- Added support for data to be specified in RNN classes as large tensors with time folded into the batch dimension instead of lists of tensors
- Significant refactoring of RNN classes
- Fixed a bunch of issues in the LayerCollection docstrings, especially around the 'reuse' argument.

PiperOrigin-RevId: 189716331
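The first bullet above refers to a new input format in which the use/time dimension is folded into the batch dimension. A minimal sketch of the two equivalent layouts, with hypothetical shapes (not part of the change itself):

# Sketch only (hypothetical shapes): the two layouts accepted by the
# multi-use/RNN registration functions after this change.
import tensorflow as tf

batch_size, num_uses, input_size = 32, 8, 128

# List format: one Tensor per use/time-step.
inputs_list = [tf.ones([batch_size, input_size]) for _ in range(num_uses)]

# Folded format (new): a single Tensor whose batch dimension also carries the
# use/time dimension, i.e. a reshape of a [batch_size, num_uses, input_size]
# Tensor.
inputs_folded = tf.reshape(tf.stack(inputs_list, axis=1),
                           [batch_size * num_uses, input_size])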
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py   12
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py                 344
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py                  65
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py               163
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py                            3
5 files changed, 314 insertions, 273 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 16f02f1199..e007f70939 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -862,8 +862,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
- tensor_list = [tensor]
- factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+ factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
@@ -872,8 +871,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- tensor_list = [tensor]
- factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False)
+ factor = ff.FullyConnectedMultiKF((tensor,), has_bias=False)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
@@ -883,8 +881,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- tensor_list = [tensor]
- factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True)
+ factor = ff.FullyConnectedMultiKF((tensor,), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
@@ -895,8 +892,7 @@ class FullyConnectedMultiKFTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- tensor_list = [tensor]
- factor = ff.FullyConnectedMultiKF((tensor_list,))
+ factor = ff.FullyConnectedMultiKF((tensor,))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 79d0424dca..f517e3148f 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -106,55 +106,6 @@ def _make_partitionedtensors_grads(grads_list):
return tuple(utils.PartitionedTensor(grads) for grads in grads_list)
-def _make_partitionedtensors_multi_inputs(inputs):
- """Constructs PartitionedTensors for inputs.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- This version of this function is for use with FisherBlocks that deal with
- multiple uses or time-steps. One PartitionedTensor is created for each
- use/time-step. The FisherBlock will be responsible for concatenating
- (or doing whatever else it wants) with the resulting lists.
-
- Args:
- inputs: a 2-D list of Tensors. First index is tower/mini-batch, second is
- use/time-step.
-
- Returns:
- A tuple of PartitionedTensor's, one per use/time-step.
- """
- num_uses = len(inputs[0])
- assert all(len(input_) == num_uses for input_ in inputs)
-
- return tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs))
-
-
-def _make_partitionedtensors_multi_grads(grads_list):
- """Constructs PartitionedTensors for grads_list.
-
- The purpose of this method is to package up the towers/minibatch dimension
- of these arrays into PartitionedTensor objects.
-
- This version of this function is for use with FisherBlocks that deal with
- multiple uses or time-steps. One PartitionedTensor is created for each
- use/time-step. The FisherBlock will be responsible for concatenating
- (or doing whatever else it wants) with the resulting lists.
-
- Args:
- grads_list: 3-D list of Tensors. First index is for source, second is for
- tower, third is for use/time-step.
-
- Returns:
- 2-D tuple of PartitionedTensors. First index is for source, second is for
- use/time-step.
- """
- num_uses = len(grads_list[0][0])
- assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
- return tuple(tuple(utils.PartitionedTensor(grad)
- for grad in zip(*grads)) for grads in grads_list)
-
-
def normalize_damping(damping, num_replications):
"""Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
if NORMALIZE_DAMPING_POWER:
@@ -662,7 +613,7 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
class KroneckerProductFB(FisherBlock):
- """A base class for FisherBlocks with separate input and output factors.
+ """A base class for blocks with separate input and output Kronecker factors.
The Fisher block is approximated as a Kronecker product of the input and
output factors.
@@ -783,67 +734,6 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB):
self._setup_damping(damping)
-class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers used multiple times in the graph.
-
- Similar to EmbeddingKFACFB except that this version supports multiple uses
- of the parameter within a single model. These uses could correspond to
- "time-steps", but they don't have to.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size):
- """Creates a EmbeddingKFACMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACMultiIndepFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
- gradient of the loss with respect to 'outputs' from source 'i',
- tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
- [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs = self._inputs
- self._num_uses = num_uses = len(inputs[0])
-
- # Check that all mini-batches/towers have the same number of uses
- assert all(len(input_) == num_uses for input_ in inputs)
- # Do the same for grads_list
- assert all(len(grad) == num_uses for grad in grads for grads in grads_list)
- # Merge uses and towers/minibatches dimensions together so we can handle
- # it using a non-multi factor.
- inputs = nest.flatten(inputs)
-
- # Note that we call the multi version of make_partitionedtensors only for
- # grads_list here.
- inputs = _make_partitionedtensors_inputs(inputs)
- grads_list = _make_partitionedtensors_multi_grads(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list,))
- self._setup_damping(damping, normalization=num_uses)
-
- @property
- def _renorm_coeff(self):
- return self._num_uses
-
-
class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
"""K-FAC FisherBlock for fully-connected (dense) layers.
@@ -1232,7 +1122,70 @@ def num_conv_locations(input_shape, strides):
return spatial_input_locations // spatial_strides_divisor
-class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class InputOutputMultiMinibatchMultiUse(InputOutputMultiMinibatch):
+ """Adds methods for multi-use/time-step case to InputOutputMultiMinibatch."""
+
+ def __init__(self, num_uses=None, *args, **kwargs):
+ self._num_uses = num_uses
+ super(InputOutputMultiMinibatchMultiUse, self).__init__(*args, **kwargs)
+
+ def _process_data(self, grads_list):
+ """Process temporal/multi-use data into a standard format."""
+
+ inputs = self._inputs
+
+ # The first possible data format is where inputs is a list of tensors,
+ # one for each use/time-step.
+ if isinstance(inputs[0], (list, tuple)):
+ # The first index is tower/minibatch, the second is use/time-step
+ num_uses = len(inputs[0])
+ if self._num_uses is not None and self._num_uses != num_uses:
+ raise ValueError("num_uses argument doesn't match length of inputs.")
+ else:
+ self._num_uses = num_uses
+
+ # Check that all mini-batches/towers have the same number of uses
+ if not all(len(input_) == num_uses for input_ in inputs):
+ raise ValueError("Length of inputs argument is inconsistent across "
+ "mini-batches/towers.")
+ # Fold uses/time-step and towers/minibatches dimensions together
+ inputs = nest.flatten(inputs)
+
+ inputs = _make_partitionedtensors_inputs(inputs)
+ # If inputs is not a tuple then we assume that inputs is a tensor
+ # with 'uses' folded into the batch dimension. (And grads_list is a list
+ # across sources of such Tensors.) This is the native format that the
+ # factor will take as arguments.
+
+ # Now we perform the analogous processing for grads_list
+ if isinstance(grads_list[0][0], (list, tuple)):
+ num_uses = len(grads_list[0][0])
+ if self._num_uses is not None and self._num_uses != num_uses:
+ raise ValueError("num_uses argument doesn't match length of outputs, "
+ "or length of outputs is inconsistent with length of "
+ "inputs.")
+ else:
+ self._num_uses = num_uses
+
+ if not all(len(grad) == num_uses for grads in grads_list
+ for grad in grads):
+ raise ValueError("Length of outputs argument is inconsistent across "
+ "mini-batches/towers.")
+
+ grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+ grads_list = _make_partitionedtensors_grads(grads_list)
+
+ if self._num_uses is None:
+ raise ValueError("You must supply a value for the num_uses argument if "
+ "the number of uses cannot be inferred from inputs or "
+ "outputs arguments (e.g. if they are both given in the "
+ "single Tensor format, instead of as lists of Tensors.")
+
+ return inputs, grads_list
+
+
+class FullyConnectedMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+ KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters.
This class implements the "independence across time" approximation from the
@@ -1240,42 +1193,43 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
https://openreview.net/pdf?id=HyMTkQZAb
"""
- def __init__(self, layer_collection, has_bias=False):
+ def __init__(self, layer_collection, has_bias=False, num_uses=None):
"""Creates a FullyConnectedMultiIndepFB block.
Args:
layer_collection: LayerCollection instance.
has_bias: bool. If True, estimates Fisher with respect to a bias
parameter as well as the layer's parameters.
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with uses/time folded into the
+ batch dimension (instead of uses/time being a list dimension).
+ (Default: None)
"""
self._has_bias = has_bias
- super(FullyConnectedMultiIndepFB, self).__init__(layer_collection)
+ super(FullyConnectedMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
def instantiate_factors(self, grads_list, damping):
-
- self._num_uses = float(len(self._inputs[0]))
- inputs = _make_partitionedtensors_multi_inputs(self._inputs)
- grads_list = _make_partitionedtensors_multi_grads(grads_list)
+ inputs, grads_list = self._process_data(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._has_bias))
+ ((inputs,), self._num_uses, self._has_bias))
self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list,))
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
self._setup_damping(damping, normalization=self._num_uses)
@property
def _renorm_coeff(self):
- return self._num_uses
-
- def tensors_to_compute_grads(self):
- return self._outputs
+ return float(self._num_uses)
-class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
+class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+ KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
Similar to ConvKFCBasicFB except that this version supports multiple
@@ -1291,7 +1245,8 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
strides=None,
dilation_rate=None,
data_format=None,
- extract_patches_fn=None):
+ extract_patches_fn=None,
+ num_uses=None):
"""Creates a ConvKFCBasicMultiIndepFB block.
Args:
@@ -1312,6 +1267,10 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
extract_patches_fn: str or None. Name of function that extracts image
patches. One of "extract_convolution_patches", "extract_image_patches",
"extract_pointwise_conv2d_patches".
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with uses/time folded into the
+ batch dimension (instead of uses/time being a list dimension).
+ (Default: None)
"""
self._padding = padding
self._strides = maybe_tuple(strides)
@@ -1323,28 +1282,16 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
fltr = params[0] if self._has_bias else params
self._filter_shape = tuple(fltr.shape.as_list())
- super(ConvKFCBasicMultiIndepFB, self).__init__(layer_collection)
+ super(ConvKFCBasicMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
def instantiate_factors(self, grads_list, damping):
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_locations = num_conv_locations(
- self._inputs[0][0].shape.as_list(), self._strides)
-
- # The first index is tower/minibatch, the second is use/time-step
- inputs = self._inputs
- self._num_uses = num_uses = len(inputs[0])
-
- # Check that all mini-batches/towers have the same number of uses
- assert all(len(input_) == num_uses for input_ in inputs)
- assert all(len(grad) == num_uses for grads in grads_list for grad in grads)
-
- # Fold uses/time-step and towers/minibatches dimensions together
- inputs = nest.flatten(inputs)
- # And do the same for grads_list
- grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+ inputs, grads_list = self._process_data(grads_list)
- inputs = _make_partitionedtensors_inputs(inputs)
- grads_list = _make_partitionedtensors_grads(grads_list)
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
@@ -1354,20 +1301,75 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
self._output_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
- self._setup_damping(damping, normalization=(num_locations * num_uses))
+ self._setup_damping(damping, normalization=
+ (self._num_locations * self._num_uses))
@property
def _renorm_coeff(self):
return self._num_locations * self._num_uses
+class EmbeddingKFACMultiIndepFB(InputOutputMultiMinibatchMultiUse,
+ KroneckerProductFB):
+ """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+ Similar to EmbeddingKFACFB except that this version supports multiple uses
+ of the parameter within a single model. These uses could correspond to time
+ steps in an RNN architecture, but they don't have to.
+
+ Does not support bias parameters.
+ """
+
+ def __init__(self, layer_collection, vocab_size, num_uses):
+ """Creates a EmbeddingKFACMultiIndepFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ vocab_size: int. Size of vocabulary for this embedding layer.
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with time folded into the batch
+ dimension (instead of time being a list dimension). (Default: None)
+ """
+ self._vocab_size = vocab_size
+
+ super(EmbeddingKFACMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+ gradient of the loss with respect to 'outputs' from source 'i',
+ tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
+ [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
+ (inputs, self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+ self._setup_damping(damping, normalization=self._num_uses)
+
+ @property
+ def _renorm_coeff(self):
+ return float(self._num_uses)
+
+
class SeriesFBApproximation(enum.IntEnum):
"""See FullyConnectedSeriesFB.__init__ for description and usage."""
option1 = 1
option2 = 2
-class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
+class FullyConnectedSeriesFB(InputOutputMultiMinibatchMultiUse,
+ KroneckerProductFB):
"""FisherBlock for fully-connected layers that share parameters across time.
This class implements the "Option 1" and "Option 2" approximation from the
@@ -1383,6 +1385,7 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
def __init__(self,
layer_collection,
has_bias=False,
+ num_uses=None,
option=SeriesFBApproximation.option2):
"""Constructs a new `FullyConnectedSeriesFB`.
@@ -1390,6 +1393,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
has_bias: Whether the layer includes a bias parameter.
+ num_uses: int or None. Number of time-steps over which the layer
+ is used. Only required if the data is formatted with time folded into
+ the batch dimension (instead of time being a list dimension).
+ (Default: None)
option: A `SeriesFBApproximation` specifying the simplifying assumption
to be used in this block. `option1` approximates the cross-covariance
over time as a symmetric matrix, while `option2` makes
@@ -1400,39 +1407,33 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
self._has_bias = has_bias
self._option = option
- super(FullyConnectedSeriesFB, self).__init__(layer_collection)
+ super(FullyConnectedSeriesFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
- def instantiate_factors(self, grads_list, damping):
+ @property
+ def _num_timesteps(self):
+ return self._num_uses
- self._num_timesteps = len(self._inputs[0])
- assert len(grads_list[0][0]) == self._num_timesteps
+ @property
+ def _renorm_coeff(self):
+ # This should no longer be used since the multiply_X functions from the base
+ # class have been overridden
+ assert False
- inputs = _make_partitionedtensors_multi_inputs(self._inputs)
- grads_list = _make_partitionedtensors_multi_grads(grads_list)
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias))
+ fisher_factors.FullyConnectedMultiKF,
+ ((inputs,), self._num_uses, self._has_bias))
self._input_factor.register_cov_dt1()
self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list,))
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
self._output_factor.register_cov_dt1()
- def compute_damping():
- normalized_damping = normalize_damping(damping, self._num_timesteps)
- return compute_pi_adjusted_damping(self._input_factor.get_cov(),
- self._output_factor.get_cov(),
- normalized_damping**0.5)
-
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- "normalize_damping",
- damping, self._num_timesteps, "power", 0.5)
- self._input_damping_func = _package_func(lambda: compute_damping()[0],
- damping_id + ("ref", 0))
- self._output_damping_func = _package_func(lambda: compute_damping()[1],
- damping_id + ("ref", 1))
+ self._setup_damping(damping, normalization=self._num_uses)
def register_matpower(self, exp):
if exp != -1:
@@ -1562,6 +1563,3 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock):
return utils.mat2d_to_layer_params(vector, Z)
# pylint: enable=invalid-name
-
- def tensors_to_compute_grads(self):
- return self._outputs
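The InputOutputMultiMinibatchMultiUse._process_data method introduced above accepts either data format and infers the number of uses when the data is given as lists. A simplified standalone sketch of that dispatch logic (a hypothetical helper, not the library code itself):

# Simplified sketch of the format dispatch performed by _process_data.
# 'inputs' is either a list (towers) of lists (uses) of Tensors, or a list of
# Tensors with uses folded into the batch dimension.
def infer_num_uses(inputs, num_uses=None):
  if isinstance(inputs[0], (list, tuple)):
    # List format: the inner, per-tower lists index uses/time-steps.
    inferred = len(inputs[0])
    if not all(len(tower) == inferred for tower in inputs):
      raise ValueError("Inconsistent number of uses across towers.")
    if num_uses is not None and num_uses != inferred:
      raise ValueError("num_uses does not match the length of inputs.")
    return inferred
  # Folded format: the number of uses cannot be inferred from the data.
  if num_uses is None:
    raise ValueError("num_uses must be supplied for the folded format.")
  return num_uses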
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 6fc163e232..f521363536 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -35,7 +35,6 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
@@ -1227,27 +1226,24 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return compute_cov(reshaped_tensor)
-class FullyConnectedMultiKF(InverseProvidingFactor):
+class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
"""Kronecker factor for a fully connected layer used multiple times."""
def __init__(self,
- tensor_lists,
+ tensors,
+ num_uses=None,
has_bias=False):
"""Constructs a new `FullyConnectedMultiKF`.
Args:
- tensor_lists: 2D array (list of lists) of Tensors of shape
- [batch_size, n]. Each of these tensors is usually a layer's inputs or
- its output's gradients. The first dimension of the array is the source,
- and the second is the use in the graph (which is sometimes a
- "time-step").
+ tensors: List of Tensors, each of shape [batch_size, n]. Each of
+ these tensors is usually a layer's inputs or its output's gradients.
+ The list is over sources.
+ num_uses: int. The number of time-steps / uses.
has_bias: bool. If True, '1' is appended to each row.
"""
- self._tensor_lists = tensor_lists
- self._has_bias = has_bias
- self._num_timesteps = len(tensor_lists[0])
- self._tensors = [None] * len(tensor_lists)
+ self._num_uses = num_uses
self._cov_dt1 = None
self._make_cov_dt1 = False
@@ -1256,20 +1252,17 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
self._option1quants_registrations = set()
self._option2quants_registrations = set()
- super(FullyConnectedMultiKF, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_fc_multi_" + scope_string_from_params(
- tuple(nest.flatten(self._tensor_lists)) + (self._has_bias,))
+ super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
+ has_bias=has_bias)
@property
- def _num_sources(self):
- return len(self._tensor_lists)
+ def _num_timesteps(self):
+ return self._num_uses
@property
- def _dtype(self):
- return self._tensor_lists[0][0].dtype
+ def _var_scope(self):
+ return "ff_fc_multi_" + scope_string_from_params(
+ tuple(self._tensors) + (self._num_timesteps, self._has_bias,))
def make_covariance_update_op(self, ema_decay):
@@ -1291,36 +1284,28 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return op
- def _compute_new_cov(self, idx=0):
- # Concatenate across time/replications
- tensor = array_ops.concat(self._tensor_lists[idx], 0)
+ def _compute_new_cov_dt1(self, idx=0): # pylint: disable=missing-docstring
+ tensor = self._tensors[idx]
if self._has_bias:
+ # This appending is technically done twice (the other time is for
+ # _compute_new_cov())
tensor = append_homog(tensor)
- # We save these so they can be used by _compute_new_cov_dt1
- self._tensors[idx] = tensor
- return compute_cov(tensor)
- def _compute_new_cov_dt1(self, idx=0): # pylint: disable=missing-docstring
- tensor = self._tensors[idx]
- batch_size = array_ops.shape(self._tensor_lists[idx][0])[0]
- # Is there a more elegant way to do this computation?
+ total_len = array_ops.shape(tensor)[0]
+ batch_size = total_len // self._num_timesteps
+
tensor_present = tensor[:-batch_size, :]
tensor_future = tensor[batch_size:, :]
+
# We specify a normalizer for this computation to ensure a PSD Fisher
# block estimate. This is equivalent to padding with zeros, as was done
# in Section B.2 of the appendix.
- normalizer = self._num_timesteps * batch_size
return compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=normalizer)
-
- @property
- def _cov_shape(self):
- size = self._tensor_lists[0][0].shape[1] + self._has_bias
- return [size, size]
+ tensor_future, tensor_right=tensor_present, normalizer=total_len)
@property
def _vec_shape(self):
- size = self._tensor_lists[0][0].shape[1] + self._has_bias
+ size = self._tensors[0].shape[1] + self._has_bias
return [size]
def get_option1quants(self, damping_func):
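The rewritten _compute_new_cov_dt1 above recovers the per-step batch size from the folded Tensor and builds a cross-covariance between consecutive time-steps. A small sketch of that slicing, assuming (as the slicing itself does) that each consecutive block of batch_size rows corresponds to one time-step:

# Sketch (hypothetical shapes) of the present/future slicing used by
# _compute_new_cov_dt1 on a folded Tensor.
import tensorflow as tf

num_uses, batch_size, n = 4, 3, 5
tensor = tf.reshape(tf.range(num_uses * batch_size * n, dtype=tf.float32),
                    [num_uses * batch_size, n])

total_len = tf.shape(tensor)[0]
bs = total_len // num_uses              # per-time-step batch size

tensor_present = tensor[:-bs, :]        # rows for time-steps 0 .. T-2
tensor_future = tensor[bs:, :]          # rows for time-steps 1 .. T-1

# Roughly what compute_cov(tensor_future, tensor_right=tensor_present,
# normalizer=total_len) evaluates to: a cross-covariance normalized by the
# full length, which is equivalent to zero-padding and keeps the estimate PSD.
cov_dt1 = tf.matmul(tensor_future, tensor_present,
                    transpose_a=True) / tf.cast(total_len, tf.float32)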
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 00eae8b399..7727c607db 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -572,13 +572,15 @@ class LayerCollection(object):
params: Embedding matrix of shape [vocab_size, embedding_size].
inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
into embedding matrix.
- outputs: Tensor of shape [batch_size, output_size]. Outputs
+ outputs: Tensor of shape [batch_size, embedding_size]. Outputs
produced by layer.
approx: str or None. If not None must be "kron". The Fisher
approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -616,9 +618,11 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -665,9 +669,11 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -743,9 +749,11 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -796,9 +804,11 @@ class LayerCollection(object):
data_format: str or None. Format of data.
approx: str or None. If not None must "diagonal". The Fisher
approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -862,9 +872,11 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -905,9 +917,10 @@ class LayerCollection(object):
approx: str or None. It not None, must be one of "full" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'batch_size' to the total
+ mini-batch size use when estimating the Fisher block for this layer
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -924,7 +937,8 @@ class LayerCollection(object):
self._add_uses(params, float("inf"))
def register_fully_connected_multi(self, params, inputs, outputs,
- approx=None, reuse=VARIABLE_SCOPE):
+ num_uses=None, approx=None,
+ reuse=VARIABLE_SCOPE):
"""Register fully connected layers with shared parameters.
This can handle general fully-connected layers with shared parameters, but
@@ -935,19 +949,31 @@ class LayerCollection(object):
params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
this layer. Weight matrix should have shape [input_size, output_size].
Bias should have shape [output_size].
- inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs
+ inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
to layer. The list indexes each use in the graph (which might
- correspond to a "time-step" in an RNN).
- outputs: A list of tensors, the same length as 'inputs', each of shape
+ correspond to a "time-step" in an RNN). OR, can be single Tensor, of
+ shape [batch_size * num_uses, input_size], which is a reshaped version
+ of a Tensor of shape [batch_size, num_uses, input_size].
+ outputs: A list of Tensors, the same length as 'inputs', each of shape
[batch_size, output_size]. Outputs produced by layer. The list indexes
each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in 'inputs'.
+ RNN). Needs to correspond with the order used in 'inputs'. OR, can be
+ a single Tensor of shape [batch_size * num_uses, output_size], which is
+ a reshaped version of a Tensor of shape [batch_size, num_uses,
+ output_size].
+ num_uses: int or None. The number of uses/time-steps in the graph where the
+ layer appears. Only needed if both inputs and outputs are given in the
+ single Tensor format. (Default: None)
approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
or "kron_series_2". The Fisher approximation to use. If None the default
value is used. (Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds inputs and outputs as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word 'use' here has a completely different meaning from "use in the graph"
+ as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -960,7 +986,8 @@ class LayerCollection(object):
# should be added back in here (and for the other block types, arguably).
has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias),
+ block = self.register_block(params, block_type(self, has_bias=has_bias,
+ num_uses=num_uses),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
@@ -973,6 +1000,7 @@ class LayerCollection(object):
padding,
inputs,
outputs,
+ num_uses=None,
data_format=None,
dilations=None,
approx=None,
@@ -988,19 +1016,32 @@ class LayerCollection(object):
padding: string. see tf.nn.conv2d for valid values.
inputs: A list of Tensors, each of shape [batch_size, height, width,
in_channels]. Inputs to layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN).
+ (which might correspond to a "time-step" in an RNN). OR, can be single
+ Tensor, of shape [batch_size * num_uses, height, width, in_channels],
+ which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ height, width, in_channels].
outputs: A list of Tensors, each of shape [batch_size, height, width,
out_channels]. Output produced by layer. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in 'inputs'.
+ Needs to correspond with the order used in 'inputs'. OR, can be a
+ single Tensor, of shape [batch_size*num_uses, height, width,
+ out_channels], which is a reshaped version of a Tensor of shape
+ [batch_size, num_uses, height, width, out_channels].
+ num_uses: int or None. The number of uses/time-steps in the graph where the
+ layer appears. Only needed if both inputs and outputs are given in the
+ single Tensor format. (Default: None)
data_format: str or None. Format of data.
dilations: List of 4 ints. Dilations along each dimension.
approx: str or None. If not None must by "kron_indep". The Fisher
approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds inputs and outputs as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word 'use' here has a completely different meaning from "use in the graph"
+ as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -1020,7 +1061,8 @@ class LayerCollection(object):
strides=strides,
data_format=data_format,
dilation_rate=dilations,
- extract_patches_fn="extract_image_patches"),
+ extract_patches_fn="extract_image_patches",
+ num_uses=num_uses),
reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
@@ -1036,6 +1078,7 @@ class LayerCollection(object):
params,
inputs,
outputs,
+ num_uses=None,
approx=None,
reuse=VARIABLE_SCOPE):
"""Registers embedding layers with shared parameters.
@@ -1045,16 +1088,29 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, input_size] and
dtype int32. Indices into embedding matrix. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- outputs: A list of Tensors, each of shape [batch_size, output_size].
+ OR, can be a single Tensor of shape [batch_size * num_uses, input_size],
+ which is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ input_size].
+ outputs: A list of Tensors, each of shape [batch_size, embedding_size].
Outputs produced by layer. The list indexes each use in the graph
(which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in 'inputs'.
+ correspond with the order used in 'inputs'. OR, can be a
+ single Tensor, of shape [batch_size*num_uses, embedding_size], which
+ is a reshaped version of a Tensor of shape [batch_size, num_uses,
+ embedding_size].
+ num_uses: int or None. The number of uses/time-steps in the graph where the
+ layer appears. Only needed if both inputs and outputs are given in the
+ single Tensor format. (Default: None)
approx: str or None. If not None must by "kron_indep". The Fisher
approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, reuse an existing FisherBlock. If False,
- create a new FisherBlock. If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds inputs and outputs as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word 'use' here has a completely different meaning from "use in the graph"
+ as it pertains to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ (Default: "VARIABLE_SCOPE")
Raises:
ValueError: For improper value to 'approx'.
@@ -1070,7 +1126,7 @@ class LayerCollection(object):
vocab_size = int(params.shape[0])
block = self.register_block(
- params, block_type(self, vocab_size), reuse=reuse)
+ params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
self._add_uses(params, len(inputs))
@@ -1093,9 +1149,10 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
- If False, create a new FisherBlock. If VARIABLE_SCOPE, use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'logits' as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
"""
loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
seed=seed)
@@ -1126,9 +1183,10 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
- If False, create a new FisherBlock. If VARIABLE_SCOPE, use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'mean' and 'var' as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
"""
loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
seed=seed)
@@ -1154,9 +1212,10 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
- If False, create a new FisherBlock. If VARIABLE_SCOPE, use
- tf.get_variable_scope().reuse.
+ reuse: bool or str. If True, this adds 'logits' as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
"""
loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
seed=seed)
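The docstring updates above describe two ways of calling the *_multi registration methods. A usage sketch for register_fully_connected_multi with the new folded-Tensor format (TF 1.x contrib API; shapes and variable names are hypothetical):

# Sketch only: registering a shared fully-connected layer whose data has
# uses/time-steps folded into the batch dimension.
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

batch_size, num_uses, input_size, output_size = 32, 8, 128, 64
w = tf.get_variable("w", shape=[input_size, output_size])

inputs = tf.ones([batch_size * num_uses, input_size])
outputs = tf.matmul(inputs, w)

layer_collection = lc.LayerCollection()
# With the folded format num_uses cannot be inferred, so it must be given.
layer_collection.register_fully_connected_multi(
    w, inputs, outputs, num_uses=num_uses)
# Alternatively, pass lists of per-use Tensors for 'inputs' and 'outputs' and
# omit num_uses, as before this change.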
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index c589b18193..c9de0c7270 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -667,6 +667,9 @@ class PartitionedTensor(object):
def __ne__(self, other):
return not self == other # pylint: disable=g-comparison-negation
+ def __getitem__(self, key):
+ return self.as_tensor()[key]
+
def as_tensor(self, dtype=None, name=None, as_ref=False):
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
assert not as_ref
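The new PartitionedTensor.__getitem__ simply indexes the concatenated Tensor returned by as_tensor(). A minimal sketch of what that allows (hypothetical shapes):

# Sketch: indexing a PartitionedTensor behaves like indexing the single
# concatenated Tensor.
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import utils

pt = utils.PartitionedTensor([tf.ones([2, 3]), tf.zeros([4, 3])])
row = pt[0]        # same as pt.as_tensor()[0]; shape [3]
tail = pt[2:, :]   # slicing goes through the same path; shape [4, 3]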