-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py   13
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py                302
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py               27
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py                           8
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils_lib.py                       1
5 files changed, 167 insertions, 184 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 70e56db055..a2665b9279 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -35,18 +35,27 @@ from tensorflow.python.platform import test
class MaybeColocateTest(test.TestCase):
+ def setUp(self):
+ self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
+
+ def tearDown(self):
+ ff.set_global_constants(
+ colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
+
def testFalse(self):
+ ff.set_global_constants(colocate_cov_ops_with_inputs=False)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a, False):
+ with ff._maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
def testTrue(self):
+ ff.set_global_constants(colocate_cov_ops_with_inputs=True)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a, True):
+ with ff._maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
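
Note: these tests exercise the new configuration path, where colocation of cov ops with their inputs is a module-level switch rather than a per-factor constructor argument. A minimal sketch of toggling it from user code, assuming the kfac contrib layout shown in this diff:

    from tensorflow.contrib.kfac.python.ops import fisher_factors as ff

    # Turn colocation off for every factor built afterwards; with the flag False,
    # _maybe_colocate_with(...) simply yields without entering a colocation scope.
    ff.set_global_constants(colocate_cov_ops_with_inputs=False)
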
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 2c00fc14d9..826e8b7732 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -52,11 +52,15 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+# Colocate the covariance ops and variables with the input tensors for each
+# factor.
+COLOCATE_COV_OPS_WITH_INPUTS = True
+
@contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
- """Context to colocate with `op` if `colocate_cov_ops_with_inputs`."""
- if colocate_cov_ops_with_inputs:
+def _maybe_colocate_with(op):
+ """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
+ if COLOCATE_COV_OPS_WITH_INPUTS:
if isinstance(op, (list, tuple)):
with tf_ops.colocate_with(op[0]):
yield
@@ -70,12 +74,14 @@ def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None):
+ eigenvalue_clipping_threshold=None,
+ colocate_cov_ops_with_inputs=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
+ global COLOCATE_COV_OPS_WITH_INPUTS
if init_covariances_at_zero is not None:
INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
@@ -85,6 +91,8 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if colocate_cov_ops_with_inputs is not None:
+ COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs
def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
@@ -264,15 +272,22 @@ class FisherFactor(object):
Returns:
An Op for updating the covariance Variable referenced by _cov.
"""
- new_cov = math_ops.add_n(
- tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
-
- # Synchronize value across all TPU cores.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
-
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+ new_cov_contribs = tuple(self._compute_new_cov(idx)
+ for idx in range(self._num_sources))
+ # This gets the job done but we might want a better solution in the future.
+ # In particular, we could have a separate way of specifying where the
+ # cov variables finally end up, independent of where their various
+ # contributions are computed. Right now these are the same thing, but in
+ # the future we might want to perform the cov computations on each tower,
+ # so that each tower will be considered a "source" (allowing us to reuse
+ # the existing "source" code for this).
+ with _maybe_colocate_with(new_cov_contribs[0]):
+ new_cov = math_ops.add_n(new_cov_contribs)
+ # Synchronize value across all TPU cores.
+ if utils.on_tpu():
+ new_cov = utils.cross_replica_mean(new_cov)
+ return moving_averages.assign_moving_average(
+ self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
@abc.abstractmethod
def make_inverse_update_ops(self):
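
Note: the updated make_covariance_update_op sums the per-source contributions, colocates the add_n with the first contribution, optionally averages across TPU replicas, and then folds the result into the stored covariance with a moving average. A toy NumPy sketch of that exponential moving average (illustrative names, not the library call; zero-debiasing omitted):

    import numpy as np

    def ema_update(cov, new_cov, ema_decay):
        # assign_moving_average effectively computes:
        #   cov <- cov - (1 - ema_decay) * (cov - new_cov)
        return cov - (1.0 - ema_decay) * (cov - new_cov)

    cov = np.zeros((2, 2))
    cov = ema_update(cov, np.eye(2), ema_decay=0.95)  # moves 5% of the way toward new_cov
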
@@ -430,45 +445,38 @@ class FullFactor(InverseProvidingFactor):
def __init__(self,
params_grads,
- batch_size,
- colocate_cov_ops_with_inputs=False):
+ batch_size):
self._batch_size = batch_size
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
- self._orig_params_grads_name = scope_string_from_params(
- [params_grads, self._batch_size])
- params_grads_flat = []
- for params_grad in params_grads:
- with _maybe_colocate_with(params_grad,
- self._colocate_cov_ops_with_inputs):
- col = utils.tensors_to_column(params_grad)
- params_grads_flat.append(col)
- self._params_grads_flat = tuple(params_grads_flat)
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
super(FullFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_full/" + self._orig_params_grads_name
+ return "ff_full/" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
@property
def _cov_shape(self):
- size = self._params_grads_flat[0].shape[0]
- return [size, size]
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return (size, size)
@property
def _num_sources(self):
- return len(self._params_grads_flat)
+ return len(self._params_grads)
@property
def _dtype(self):
- return self._params_grads_flat[0].dtype
+ return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
# This will be a very basic rank 1 estimate
- with _maybe_colocate_with(self._params_grads_flat[idx],
- self._colocate_cov_ops_with_inputs):
- return ((self._params_grads_flat[idx] * array_ops.transpose(
- self._params_grads_flat[idx])) / math_ops.cast(
- self._batch_size, self._params_grads_flat[idx].dtype))
+ with _maybe_colocate_with(self._params_grads[idx]):
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return ((params_grads_flat * array_ops.transpose(
+ params_grads_flat)) / math_ops.cast(self._batch_size,
+ params_grads_flat.dtype))
class DiagonalFactor(FisherFactor):
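
Note: FullFactor now keeps the raw parameter gradients and flattens them lazily inside _compute_new_cov, under the colocation context. The estimate itself is unchanged: a rank-1 outer product of the flattened gradient column with itself, divided by the batch size. A toy NumPy illustration, not the library code:

    import numpy as np

    grad_flat = np.array([[1.0], [2.0], [3.0]])           # flattened gradient column, shape [n, 1]
    batch_size = 4
    cov_estimate = grad_flat @ grad_flat.T / batch_size   # shape [n, n], rank 1
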
@@ -494,28 +502,22 @@ class NaiveDiagonalFactor(DiagonalFactor):
def __init__(self,
params_grads,
- batch_size,
- colocate_cov_ops_with_inputs=False):
+ batch_size):
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
self._batch_size = batch_size
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
- params_grads_flat = []
- for params_grad in params_grads:
- with _maybe_colocate_with(params_grad,
- self._colocate_cov_ops_with_inputs):
- col = utils.tensors_to_column(params_grad)
- params_grads_flat.append(col)
- self._params_grads = tuple(params_grads_flat)
- self._orig_params_grads_name = scope_string_from_params(
- [self._params_grads, self._batch_size])
super(NaiveDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_naivediag/" + self._orig_params_grads_name
+ return "ff_naivediag/" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
@property
def _cov_shape(self):
- return self._params_grads[0].shape
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return (size, 1)
@property
def _num_sources(self):
@@ -523,13 +525,13 @@ class NaiveDiagonalFactor(DiagonalFactor):
@property
def _dtype(self):
- return self._params_grads[0].dtype
+ return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._params_grads[idx],
- self._colocate_cov_ops_with_inputs):
- return (math_ops.square(self._params_grads[idx]) / math_ops.cast(
- self._batch_size, self._params_grads[idx].dtype))
+ with _maybe_colocate_with(self._params_grads[idx]):
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return (math_ops.square(params_grads_flat) / math_ops.cast(
+ self._batch_size, params_grads_flat.dtype))
class FullyConnectedDiagonalFactor(DiagonalFactor):
@@ -543,13 +545,10 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
where the square is taken element-wise.
"""
- # TODO(jamesmartens): add units tests for this class
-
def __init__(self,
inputs,
outputs_grads,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Instantiate FullyConnectedDiagonalFactor.
Args:
@@ -558,32 +557,24 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
outputs_grads: List of Tensors of shape [batch_size, output_size].
Gradient of loss with respect to layer's preactivations.
has_bias: bool. If True, append '1' to each input.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
+ self._inputs = inputs
+ self._has_bias = has_bias
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
self._batch_size = array_ops.shape(inputs)[0]
- self._orig_tensors_name = scope_string_from_params(
- (inputs,) + tuple(outputs_grads))
-
- # Note that we precompute the required operations on the inputs since the
- # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
- # the target entry of _outputs_grads changes with idx.)
- with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
- if has_bias:
- inputs = _append_homog(inputs)
- self._squared_inputs = math_ops.square(inputs)
+ self._squared_inputs = None
super(FullyConnectedDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_diagfc/" + self._orig_tensors_name
+ return "ff_diagfc/" + scope_string_from_params(
+ (self._inputs,) + tuple(self._outputs_grads))
@property
def _cov_shape(self):
- return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]]
+ return [self._inputs.shape[1] + self._has_bias,
+ self._outputs_grads[0].shape[1]]
@property
def _num_sources(self):
@@ -598,8 +589,14 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
# square of an outer product is the outer-product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
- with _maybe_colocate_with(self._squared_inputs,
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
+ # We only need to compute squared_inputs once
+ if self._squared_inputs is None:
+ inputs = self._inputs
+ if self._has_bias:
+ inputs = _append_homog(self._inputs)
+ self._squared_inputs = math_ops.square(inputs)
+
new_cov = math_ops.matmul(
self._squared_inputs,
math_ops.square(self._outputs_grads[idx]),
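
Note: as the comment above says, the element-wise square of the input/output-gradient outer product factorizes into the outer product of the element-wise squares, which is what the matmul of squared inputs and squared output gradients computes. A toy NumPy sketch with assumed shapes and normalization:

    import numpy as np

    inputs = np.random.randn(8, 5)         # [batch_size, input_size]
    output_grads = np.random.randn(8, 3)   # [batch_size, output_size]
    # Entry-wise, (a g^T)**2 == (a**2)(g**2)^T, so square both and matmul.
    new_cov = (inputs ** 2).T @ (output_grads ** 2) / inputs.shape[0]  # [input_size, output_size]
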
@@ -611,16 +608,13 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
class ConvDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
- # TODO(jamesmartens): add units tests for this class
-
def __init__(self,
inputs,
outputs_grads,
filter_shape,
strides,
padding,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Creates a ConvDiagonalFactor object.
Args:
@@ -635,42 +629,21 @@ class ConvDiagonalFactor(DiagonalFactor):
padding: The padding in this layer (1-D of Tensor length 4).
has_bias: Python bool. If True, the layer is assumed to have a bias
parameter in addition to its filter parameter.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
+ self._inputs = inputs
self._filter_shape = filter_shape
+ self._strides = strides
+ self._padding = padding
self._has_bias = has_bias
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
-
- self._orig_tensors_name = scope_string_from_name(
- (inputs,) + tuple(outputs_grads))
-
- # Note that we precompute the required operations on the inputs since the
- # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
- # the target entry of _outputs_grads changes with idx.)
- with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=strides,
- rates=[1, 1, 1, 1],
- padding=padding)
-
- if has_bias:
- patches = _append_homog(patches)
-
- self._patches = patches
+ self._patches = None
super(ConvDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_convdiag/" + self._orig_tensors_name
+ return "ff_convdiag/" + scope_string_from_name(
+ (self._inputs,) + tuple(self._outputs_grads))
@property
def _cov_shape(self):
@@ -689,8 +662,24 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
+ if self._patches is None:
+ filter_height, filter_width, _, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = _append_homog(patches)
+
+ self._patches = patches
+
outputs_grad = self._outputs_grads[idx]
batch_size = array_ops.shape(self._patches)[0]
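
Note: the patches tensor is now built lazily on the first _compute_new_cov call, under the colocation context of the output gradients. extract_image_patches returns a tensor of shape [batch, out_height, out_width, filter_height * filter_width * in_channels]; a small standalone shape check with toy sizes:

    import tensorflow as tf

    images = tf.zeros([2, 8, 8, 3])  # [batch, height, width, in_channels]
    patches = tf.extract_image_patches(
        images,
        ksizes=[1, 3, 3, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="SAME")
    print(patches.shape)  # (2, 8, 8, 27): depth is filter_h * filter_w * in_channels
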
@@ -714,22 +703,18 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
def __init__(self,
tensors,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Instantiate FullyConnectedKroneckerFactor.
Args:
tensors: List of Tensors of shape [batch_size, n]. Represents either a
layer's inputs or its output's gradients.
has_bias: bool. If True, append '1' to each row.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
# The tensor argument is either a tensor of input activations or a tensor of
# output pre-activation gradients.
self._has_bias = has_bias
self._tensors = tensors
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(FullyConnectedKroneckerFactor, self).__init__()
@property
@@ -751,8 +736,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensors[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._tensors[idx]):
tensor = self._tensors[idx]
if self._has_bias:
tensor = _append_homog(tensor)
@@ -774,8 +758,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
filter_shape,
strides,
padding,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Initializes ConvInputKroneckerFactor.
Args:
@@ -787,15 +770,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
width_stride, in_channel_stride].
padding: str. Padding method for layer. "SAME" or "VALID".
has_bias: bool. If True, append 1 to in_channel.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
self._has_bias = has_bias
self._inputs = inputs
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(ConvInputKroneckerFactor, self).__init__()
@property
@@ -823,8 +803,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- # TODO(jamesmartens): factor this patches stuff out into a utility function
- with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._inputs):
filter_height, filter_width, in_channels, _ = self._filter_shape
# TODO(b/64144716): there is potential here for a big savings in terms of
@@ -868,18 +847,15 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
Section 3.1 Estimating the factors.
"""
- def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False):
+ def __init__(self, outputs_grads):
"""Initializes ConvOutputKroneckerFactor.
Args:
outputs_grads: list of Tensors. Each Tensor is of shape
[batch_size, height, width, out_channels].
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
self._out_channels = outputs_grads[0].shape.as_list()[3]
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(ConvOutputKroneckerFactor, self).__init__()
@property
@@ -900,8 +876,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
# M = minibatch size, |T| = number of spatial locations, and
@@ -920,57 +895,49 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def __init__(self,
tensor_lists,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Constructs a new `FullyConnectedMultiKF`.
Args:
tensor_lists: List of lists of Tensors of shape [batch_size, n].
has_bias: bool. If True, '1' is appended to each row.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
- self._orig_tensors_name = scope_string_from_params(tensor_lists)
+ self._tensor_lists = tensor_lists
+ self._has_bias = has_bias
self._batch_size = array_ops.shape(tensor_lists[0][0])[0]
self._num_timesteps = len(tensor_lists[0])
-
- tensors = tuple(
- array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists)
- if has_bias:
- tensors = tuple(_append_homog(tensor) for tensor in tensors)
- self._tensors = tensors
+ self._tensors = [None] * len(tensor_lists)
self._cov_dt1 = None
self._option1quants_by_damping = {}
self._option2quants_by_damping = {}
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(FullyConnectedMultiKF, self).__init__()
@property
def _var_scope(self):
- return "ff_fc_multi/" + self._orig_tensors_name
+ return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists)
@property
def _num_sources(self):
- return len(self._tensors)
+ return len(self._tensor_lists)
@property
def _dtype(self):
- return self._tensors[0].dtype
+ return self._tensor_lists[0][0].dtype
def make_covariance_update_op(self, ema_decay):
- with _maybe_colocate_with(self._tensors,
- self._colocate_cov_ops_with_inputs):
- op = super(FullyConnectedMultiKF,
- self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1 = math_ops.add_n(
- tuple(
- self._compute_new_cov_dt1(idx)
- for idx in range(self._num_sources)))
+
+ op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
+
+ if self._cov_dt1 is not None:
+ new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
+ for idx in range(self._num_sources))
+
+ with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+ new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
+
op2 = moving_averages.assign_moving_average(
self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
@@ -984,26 +951,35 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return op
def _compute_new_cov(self, idx=0):
- tensor = self._tensors[idx]
- normalizer = self._num_timesteps * self._batch_size
- return _compute_cov(tensor, normalizer=normalizer)
+ with _maybe_colocate_with(self._tensor_lists[idx]):
+ tensor = array_ops.concat(self._tensor_lists[idx], 0)
+ if self._has_bias:
+ tensor = _append_homog(tensor)
+ # We save these so they can be used by _compute_new_cov_dt1
+ self._tensors[idx] = tensor
+ return _compute_cov(tensor)
def _compute_new_cov_dt1(self, idx=0):
tensor = self._tensors[idx]
- normalizer = self._num_timesteps * self._batch_size
- tensor_present = tensor[:-self._batch_size, :]
- tensor_future = tensor[self._batch_size:, :]
- return _compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=normalizer)
+ with _maybe_colocate_with(tensor):
+ # Is there a more elegant way to do this computation?
+ tensor_present = tensor[:-self._batch_size, :]
+ tensor_future = tensor[self._batch_size:, :]
+ # We specify a normalizer for this computation to ensure a PSD Fisher
+ # block estimate. This is equivalent to padding with zeros, as was done
+ # in Section B.2 of the appendix.
+ normalizer = self._num_timesteps * self._batch_size
+ return _compute_cov(
+ tensor_future, tensor_right=tensor_present, normalizer=normalizer)
@property
def _cov_shape(self):
- size = self._tensors[0].shape[1]
+ size = self._tensor_lists[0][0].shape[1] + self._has_bias
return [size, size]
@property
def _vec_shape(self):
- size = self._tensors[0].shape[1]
+ size = self._tensor_lists[0][0].shape[1] + self._has_bias
return [size]
def get_option1quants(self, damping):
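
Note: FullyConnectedMultiKF now concatenates each tensor list (and appends the bias coordinate) lazily inside _compute_new_cov, caching the result for _compute_new_cov_dt1, which correlates each timestep with the next by dropping one batch from either end. A toy NumPy sketch of the slicing and the PSD-preserving normalizer described in the comment; which side _compute_cov transposes is an assumption here:

    import numpy as np

    batch_size, num_timesteps, n = 2, 3, 4
    # After the concat, rows are ordered [t=0 batch; t=1 batch; t=2 batch].
    tensor = np.random.randn(batch_size * num_timesteps, n)

    present = tensor[:-batch_size, :]        # timesteps 0 .. T-2
    future = tensor[batch_size:, :]          # timesteps 1 .. T-1
    normalizer = num_timesteps * batch_size  # "pad with zeros" normalizer from the comment
    cov_dt1 = future.T @ present / normalizer
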
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ca42afe6fb..8d450f04f3 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -76,14 +76,6 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
VARIABLE_SCOPE = "VARIABLE_SCOPE"
-def ensure_sequence(obj):
- """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
- if isinstance(obj, (tuple, list)):
- return obj
- else:
- return (obj,)
-
-
class LayerParametersDict(OrderedDict):
"""An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -142,7 +134,6 @@ class LayerCollection(object):
def __init__(self,
graph=None,
- colocate_cov_ops_with_inputs=False,
name="LayerCollection"):
self.fisher_blocks = LayerParametersDict()
self.fisher_factors = OrderedDict()
@@ -156,7 +147,6 @@ class LayerCollection(object):
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
APPROX_KRONECKER_SERIES_2_NAME)
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
with variable_scope.variable_scope(None, default_name=name) as scope:
self._var_scope = scope.name
@@ -169,7 +159,7 @@ class LayerCollection(object):
@property
def registered_variables(self):
"""A tuple of all of the variables currently registered."""
- tuple_of_tuples = (ensure_sequence(key) for key, block
+ tuple_of_tuples = (utils.ensure_sequence(key) for key, block
in six.iteritems(self.fisher_blocks))
flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
return flat_tuple
@@ -276,9 +266,9 @@ class LayerCollection(object):
variable_to_block = {
var: (params, block)
for (params, block) in self.fisher_blocks.items()
- for var in ensure_sequence(params)
+ for var in utils.ensure_sequence(params)
}
- for variable in ensure_sequence(layer_key):
+ for variable in utils.ensure_sequence(layer_key):
if variable in variable_to_block:
prev_key, prev_block = variable_to_block[variable]
raise ValueError(
@@ -301,7 +291,7 @@ class LayerCollection(object):
block.num_inputs()*block.num_registered_minibatches if isinstance(
block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
else block.num_registered_minibatches)
- key = ensure_sequence(key)
+ key = utils.ensure_sequence(key)
for k in key:
vars_to_uses[k] += n
return vars_to_uses
@@ -382,12 +372,12 @@ class LayerCollection(object):
ValueError: If the parameters were already registered in a layer or
identified as part of an incompatible group.
"""
- params = frozenset(ensure_sequence(params))
+ params = frozenset(utils.ensure_sequence(params))
# Check if any of the variables in 'params' is already in
# 'self.fisher_blocks.keys()'.
for registered_params, fisher_block in self.fisher_blocks.items():
- registered_params_set = set(ensure_sequence(registered_params))
+ registered_params_set = set(utils.ensure_sequence(registered_params))
for variable in params:
if (variable in registered_params_set and
params != registered_params_set):
@@ -421,7 +411,7 @@ class LayerCollection(object):
def _get_linked_approx(self, params):
"""If params were linked, return their specified approximation."""
- params_set = frozenset(ensure_sequence(params))
+ params_set = frozenset(utils.ensure_sequence(params))
if params_set in self.linked_parameters:
return self.linked_parameters[params_set]
else:
@@ -727,7 +717,6 @@ class LayerCollection(object):
key = cls, args
if key not in self.fisher_factors:
- colo = self._colocate_cov_ops_with_inputs
with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo)
+ self.fisher_factors[key] = cls(*args)
return self.fisher_factors[key]
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 48b191ef50..d717f427e6 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -344,5 +344,13 @@ def cross_replica_mean(tensor, name=None):
return tpu_ops.cross_replica_sum(tensor / num_shards)
+def ensure_sequence(obj):
+ """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
+ if isinstance(obj, (tuple, list)):
+ return obj
+ else:
+ return (obj,)
+
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 8903c90fbc..074dc579da 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -37,6 +37,7 @@ _allowed_symbols = [
"SubGraph",
"generate_random_signs",
"fwd_gradients",
+ "ensure_sequence",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)