aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-12-31 09:53:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-31 09:57:05 -0800
commit13fc601fa38c05d9384dbf657d0ec0555c03e140 (patch)
tree7abcaba50f189123a939423c13155544b9de86a7
parent2cdbdee8f4059f65e3dd7f96bb328f790df547f6 (diff)
Design tweaks of Fisher Factor classes. Now they don't declare ops in their constructors, except possibly for making the cov variables (is this considered an op?). This should allow for easier control over device placement of various ops in the future.
Fixed some problems with how colocations were done for ops computed in the base class and in the RNN class. Colocations now controlled with global configuration variable (similar to the rest of the configuration of the FisherFactor classes). PiperOrigin-RevId: 180441903
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py13
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py302
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py27
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py8
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py1
5 files changed, 167 insertions, 184 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 70e56db055..a2665b9279 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -35,18 +35,27 @@ from tensorflow.python.platform import test
class MaybeColocateTest(test.TestCase):
+ def setUp(self):
+ self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
+
+ def tearDown(self):
+ ff.set_global_constants(
+ colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
+
def testFalse(self):
+ ff.set_global_constants(colocate_cov_ops_with_inputs=False)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a, False):
+ with ff._maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@b'], b.op.colocation_groups())
def testTrue(self):
+ ff.set_global_constants(colocate_cov_ops_with_inputs=True)
with tf_ops.Graph().as_default():
a = constant_op.constant([2.0], name='a')
- with ff._maybe_colocate_with(a, True):
+ with ff._maybe_colocate_with(a):
b = constant_op.constant(3.0, name='b')
self.assertEqual([b'loc:@a'], a.op.colocation_groups())
self.assertEqual([b'loc:@a'], b.op.colocation_groups())
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 2c00fc14d9..826e8b7732 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -52,11 +52,15 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+# Colocate the covariance ops and variables with the input tensors for each
+# factor.
+COLOCATE_COV_OPS_WITH_INPUTS = True
+
@contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
- """Context to colocate with `op` if `colocate_cov_ops_with_inputs`."""
- if colocate_cov_ops_with_inputs:
+def _maybe_colocate_with(op):
+ """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
+ if COLOCATE_COV_OPS_WITH_INPUTS:
if isinstance(op, (list, tuple)):
with tf_ops.colocate_with(op[0]):
yield
@@ -70,12 +74,14 @@ def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None):
+ eigenvalue_clipping_threshold=None,
+ colocate_cov_ops_with_inputs=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
+ global COLOCATE_COV_OPS_WITH_INPUTS
if init_covariances_at_zero is not None:
INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
@@ -85,6 +91,8 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if colocate_cov_ops_with_inputs is not None:
+ COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs
def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
@@ -264,15 +272,22 @@ class FisherFactor(object):
Returns:
An Op for updating the covariance Variable referenced by _cov.
"""
- new_cov = math_ops.add_n(
- tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
-
- # Synchronize value across all TPU cores.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
-
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+ new_cov_contribs = tuple(self._compute_new_cov(idx)
+ for idx in range(self._num_sources))
+ # This gets the job done but we might want a better solution in the future.
+ # In particular, we could have a separate way of specifying where the
+ # the cov variables finally end up, independent of where their various
+ # contributions are computed. Right now these are the same thing, but in
+ # the future we might want to perform the cov computations on each tower,
+ # so that each tower will be considered a "source" (allowing us to reuse
+ # the existing "source" code for this).
+ with _maybe_colocate_with(new_cov_contribs[0]):
+ new_cov = math_ops.add_n(new_cov_contribs)
+ # Synchronize value across all TPU cores.
+ if utils.on_tpu():
+ new_cov = utils.cross_replica_mean(new_cov)
+ return moving_averages.assign_moving_average(
+ self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
@abc.abstractmethod
def make_inverse_update_ops(self):
@@ -430,45 +445,38 @@ class FullFactor(InverseProvidingFactor):
def __init__(self,
params_grads,
- batch_size,
- colocate_cov_ops_with_inputs=False):
+ batch_size):
self._batch_size = batch_size
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
- self._orig_params_grads_name = scope_string_from_params(
- [params_grads, self._batch_size])
- params_grads_flat = []
- for params_grad in params_grads:
- with _maybe_colocate_with(params_grad,
- self._colocate_cov_ops_with_inputs):
- col = utils.tensors_to_column(params_grad)
- params_grads_flat.append(col)
- self._params_grads_flat = tuple(params_grads_flat)
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
super(FullFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_full/" + self._orig_params_grads_name
+ return "ff_full/" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
@property
def _cov_shape(self):
- size = self._params_grads_flat[0].shape[0]
- return [size, size]
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return (size, size)
@property
def _num_sources(self):
- return len(self._params_grads_flat)
+ return len(self._params_grads)
@property
def _dtype(self):
- return self._params_grads_flat[0].dtype
+ return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
# This will be a very basic rank 1 estimate
- with _maybe_colocate_with(self._params_grads_flat[idx],
- self._colocate_cov_ops_with_inputs):
- return ((self._params_grads_flat[idx] * array_ops.transpose(
- self._params_grads_flat[idx])) / math_ops.cast(
- self._batch_size, self._params_grads_flat[idx].dtype))
+ with _maybe_colocate_with(self._params_grads[idx]):
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return ((params_grads_flat * array_ops.transpose(
+ params_grads_flat)) / math_ops.cast(self._batch_size,
+ params_grads_flat.dtype))
class DiagonalFactor(FisherFactor):
@@ -494,28 +502,22 @@ class NaiveDiagonalFactor(DiagonalFactor):
def __init__(self,
params_grads,
- batch_size,
- colocate_cov_ops_with_inputs=False):
+ batch_size):
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
self._batch_size = batch_size
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
- params_grads_flat = []
- for params_grad in params_grads:
- with _maybe_colocate_with(params_grad,
- self._colocate_cov_ops_with_inputs):
- col = utils.tensors_to_column(params_grad)
- params_grads_flat.append(col)
- self._params_grads = tuple(params_grads_flat)
- self._orig_params_grads_name = scope_string_from_params(
- [self._params_grads, self._batch_size])
super(NaiveDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_naivediag/" + self._orig_params_grads_name
+ return "ff_naivediag/" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
@property
def _cov_shape(self):
- return self._params_grads[0].shape
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return (size, 1)
@property
def _num_sources(self):
@@ -523,13 +525,13 @@ class NaiveDiagonalFactor(DiagonalFactor):
@property
def _dtype(self):
- return self._params_grads[0].dtype
+ return self._params_grads[0][0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._params_grads[idx],
- self._colocate_cov_ops_with_inputs):
- return (math_ops.square(self._params_grads[idx]) / math_ops.cast(
- self._batch_size, self._params_grads[idx].dtype))
+ with _maybe_colocate_with(self._params_grads[idx]):
+ params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+ return (math_ops.square(params_grads_flat) / math_ops.cast(
+ self._batch_size, params_grads_flat.dtype))
class FullyConnectedDiagonalFactor(DiagonalFactor):
@@ -543,13 +545,10 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
where the square is taken element-wise.
"""
- # TODO(jamesmartens): add units tests for this class
-
def __init__(self,
inputs,
outputs_grads,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Instantiate FullyConnectedDiagonalFactor.
Args:
@@ -558,32 +557,24 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
outputs_grads: List of Tensors of shape [batch_size, output_size].
Gradient of loss with respect to layer's preactivations.
has_bias: bool. If True, append '1' to each input.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
+ self._inputs = inputs
+ self._has_bias = has_bias
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
self._batch_size = array_ops.shape(inputs)[0]
- self._orig_tensors_name = scope_string_from_params(
- (inputs,) + tuple(outputs_grads))
-
- # Note that we precompute the required operations on the inputs since the
- # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
- # the target entry of _outputs_grads changes with idx.)
- with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
- if has_bias:
- inputs = _append_homog(inputs)
- self._squared_inputs = math_ops.square(inputs)
+ self._squared_inputs = None
super(FullyConnectedDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_diagfc/" + self._orig_tensors_name
+ return "ff_diagfc/" + scope_string_from_params(
+ (self._inputs,) + tuple(self._outputs_grads))
@property
def _cov_shape(self):
- return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]]
+ return [self._inputs.shape[1] + self._has_bias,
+ self._outputs_grads[0].shape[1]]
@property
def _num_sources(self):
@@ -598,8 +589,14 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
# square of an outer product is the outer-product of the entry-wise squares.
# The gradient is the outer product of the input and the output gradients,
# so we just square both and then take their outer-product.
- with _maybe_colocate_with(self._squared_inputs,
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
+ # We only need to compute squared_inputs once
+ if self._squared_inputs is None:
+ inputs = self._inputs
+ if self._has_bias:
+ inputs = _append_homog(self._inputs)
+ self._squared_inputs = math_ops.square(inputs)
+
new_cov = math_ops.matmul(
self._squared_inputs,
math_ops.square(self._outputs_grads[idx]),
@@ -611,16 +608,13 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
class ConvDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
- # TODO(jamesmartens): add units tests for this class
-
def __init__(self,
inputs,
outputs_grads,
filter_shape,
strides,
padding,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Creates a ConvDiagonalFactor object.
Args:
@@ -635,42 +629,21 @@ class ConvDiagonalFactor(DiagonalFactor):
padding: The padding in this layer (1-D of Tensor length 4).
has_bias: Python bool. If True, the layer is assumed to have a bias
parameter in addition to its filter parameter.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
+ self._inputs = inputs
self._filter_shape = filter_shape
+ self._strides = strides
+ self._padding = padding
self._has_bias = has_bias
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
-
- self._orig_tensors_name = scope_string_from_name(
- (inputs,) + tuple(outputs_grads))
-
- # Note that we precompute the required operations on the inputs since the
- # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
- # the target entry of _outputs_grads changes with idx.)
- with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=strides,
- rates=[1, 1, 1, 1],
- padding=padding)
-
- if has_bias:
- patches = _append_homog(patches)
-
- self._patches = patches
+ self._patches = None
super(ConvDiagonalFactor, self).__init__()
@property
def _var_scope(self):
- return "ff_convdiag/" + self._orig_tensors_name
+ return "ff_convdiag/" + scope_string_from_name(
+ (self._inputs,) + tuple(self._outputs_grads))
@property
def _cov_shape(self):
@@ -689,8 +662,24 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
+ if self._patches is None:
+ filter_height, filter_width, _, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = _append_homog(patches)
+
+ self._patches = patches
+
outputs_grad = self._outputs_grads[idx]
batch_size = array_ops.shape(self._patches)[0]
@@ -714,22 +703,18 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
def __init__(self,
tensors,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Instantiate FullyConnectedKroneckerFactor.
Args:
tensors: List of Tensors of shape [batch_size, n]. Represents either a
layer's inputs or its output's gradients.
has_bias: bool. If True, append '1' to each row.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
# The tensor argument is either a tensor of input activations or a tensor of
# output pre-activation gradients.
self._has_bias = has_bias
self._tensors = tensors
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(FullyConnectedKroneckerFactor, self).__init__()
@property
@@ -751,8 +736,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._tensors[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._tensors[idx]):
tensor = self._tensors[idx]
if self._has_bias:
tensor = _append_homog(tensor)
@@ -774,8 +758,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
filter_shape,
strides,
padding,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Initializes ConvInputKroneckerFactor.
Args:
@@ -787,15 +770,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
width_stride, in_channel_stride].
padding: str. Padding method for layer. "SAME" or "VALID".
has_bias: bool. If True, append 1 to in_channel.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
self._filter_shape = filter_shape
self._strides = strides
self._padding = padding
self._has_bias = has_bias
self._inputs = inputs
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(ConvInputKroneckerFactor, self).__init__()
@property
@@ -823,8 +803,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
if idx != 0:
raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
- # TODO(jamesmartens): factor this patches stuff out into a utility function
- with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._inputs):
filter_height, filter_width, in_channels, _ = self._filter_shape
# TODO(b/64144716): there is potential here for a big savings in terms of
@@ -868,18 +847,15 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
Section 3.1 Estimating the factors.
"""
- def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False):
+ def __init__(self, outputs_grads):
"""Initializes ConvOutputKroneckerFactor.
Args:
outputs_grads: list of Tensors. Each Tensor is of shape
[batch_size, height, width, out_channels].
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
self._out_channels = outputs_grads[0].shape.as_list()[3]
self._outputs_grads = outputs_grads
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(ConvOutputKroneckerFactor, self).__init__()
@property
@@ -900,8 +876,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
return self._outputs_grads[0].dtype
def _compute_new_cov(self, idx=0):
- with _maybe_colocate_with(self._outputs_grads[idx],
- self._colocate_cov_ops_with_inputs):
+ with _maybe_colocate_with(self._outputs_grads[idx]):
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
# M = minibatch size, |T| = number of spatial locations, and
@@ -920,57 +895,49 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def __init__(self,
tensor_lists,
- has_bias=False,
- colocate_cov_ops_with_inputs=False):
+ has_bias=False):
"""Constructs a new `FullyConnectedMultiKF`.
Args:
tensor_lists: List of lists of Tensors of shape [batch_size, n].
has_bias: bool. If True, '1' is appended to each row.
- colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
- their inputs.
"""
- self._orig_tensors_name = scope_string_from_params(tensor_lists)
+ self._tensor_lists = tensor_lists
+ self._has_bias = has_bias
self._batch_size = array_ops.shape(tensor_lists[0][0])[0]
self._num_timesteps = len(tensor_lists[0])
-
- tensors = tuple(
- array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists)
- if has_bias:
- tensors = tuple(_append_homog(tensor) for tensor in tensors)
- self._tensors = tensors
+ self._tensors = [None] * len(tensor_lists)
self._cov_dt1 = None
self._option1quants_by_damping = {}
self._option2quants_by_damping = {}
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
super(FullyConnectedMultiKF, self).__init__()
@property
def _var_scope(self):
- return "ff_fc_multi/" + self._orig_tensors_name
+ return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists)
@property
def _num_sources(self):
- return len(self._tensors)
+ return len(self._tensor_lists)
@property
def _dtype(self):
- return self._tensors[0].dtype
+ return self._tensor_lists[0][0].dtype
def make_covariance_update_op(self, ema_decay):
- with _maybe_colocate_with(self._tensors,
- self._colocate_cov_ops_with_inputs):
- op = super(FullyConnectedMultiKF,
- self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1 = math_ops.add_n(
- tuple(
- self._compute_new_cov_dt1(idx)
- for idx in range(self._num_sources)))
+
+ op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
+
+ if self._cov_dt1 is not None:
+ new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx)
+ for idx in range(self._num_sources))
+
+ with _maybe_colocate_with(new_cov_dt1_contribs[0]):
+ new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs)
+
op2 = moving_averages.assign_moving_average(
self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
@@ -984,26 +951,35 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
return op
def _compute_new_cov(self, idx=0):
- tensor = self._tensors[idx]
- normalizer = self._num_timesteps * self._batch_size
- return _compute_cov(tensor, normalizer=normalizer)
+ with _maybe_colocate_with(self._tensor_lists[idx]):
+ tensor = array_ops.concat(self._tensor_lists[idx], 0)
+ if self._has_bias:
+ tensor = _append_homog(tensor)
+ # We save these so they can be used by _compute_new_cov_dt1
+ self._tensors[idx] = tensor
+ return _compute_cov(tensor)
def _compute_new_cov_dt1(self, idx=0):
tensor = self._tensors[idx]
- normalizer = self._num_timesteps * self._batch_size
- tensor_present = tensor[:-self._batch_size, :]
- tensor_future = tensor[self._batch_size:, :]
- return _compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=normalizer)
+ with _maybe_colocate_with(tensor):
+ # Is there a more elegant way to do this computation?
+ tensor_present = tensor[:-self._batch_size, :]
+ tensor_future = tensor[self._batch_size:, :]
+ # We specify a normalizer for this computation to ensure a PSD Fisher
+ # block estimate. This is equivalent to padding with zeros, as was done
+ # in Section B.2 of the appendix.
+ normalizer = self._num_timesteps * self._batch_size
+ return _compute_cov(
+ tensor_future, tensor_right=tensor_present, normalizer=normalizer)
@property
def _cov_shape(self):
- size = self._tensors[0].shape[1]
+ size = self._tensor_lists[0][0].shape[1] + self._has_bias
return [size, size]
@property
def _vec_shape(self):
- size = self._tensors[0].shape[1]
+ size = self._tensor_lists[0][0].shape[1] + self._has_bias
return [size]
def get_option1quants(self, damping):
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ca42afe6fb..8d450f04f3 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -76,14 +76,6 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
VARIABLE_SCOPE = "VARIABLE_SCOPE"
-def ensure_sequence(obj):
- """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
- if isinstance(obj, (tuple, list)):
- return obj
- else:
- return (obj,)
-
-
class LayerParametersDict(OrderedDict):
"""An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -142,7 +134,6 @@ class LayerCollection(object):
def __init__(self,
graph=None,
- colocate_cov_ops_with_inputs=False,
name="LayerCollection"):
self.fisher_blocks = LayerParametersDict()
self.fisher_factors = OrderedDict()
@@ -156,7 +147,6 @@ class LayerCollection(object):
self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
self._default_fully_connected_multi_approximation = (
APPROX_KRONECKER_SERIES_2_NAME)
- self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
with variable_scope.variable_scope(None, default_name=name) as scope:
self._var_scope = scope.name
@@ -169,7 +159,7 @@ class LayerCollection(object):
@property
def registered_variables(self):
"""A tuple of all of the variables currently registered."""
- tuple_of_tuples = (ensure_sequence(key) for key, block
+ tuple_of_tuples = (utils.ensure_sequence(key) for key, block
in six.iteritems(self.fisher_blocks))
flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
return flat_tuple
@@ -276,9 +266,9 @@ class LayerCollection(object):
variable_to_block = {
var: (params, block)
for (params, block) in self.fisher_blocks.items()
- for var in ensure_sequence(params)
+ for var in utils.ensure_sequence(params)
}
- for variable in ensure_sequence(layer_key):
+ for variable in utils.ensure_sequence(layer_key):
if variable in variable_to_block:
prev_key, prev_block = variable_to_block[variable]
raise ValueError(
@@ -301,7 +291,7 @@ class LayerCollection(object):
block.num_inputs()*block.num_registered_minibatches if isinstance(
block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
else block.num_registered_minibatches)
- key = ensure_sequence(key)
+ key = utils.ensure_sequence(key)
for k in key:
vars_to_uses[k] += n
return vars_to_uses
@@ -382,12 +372,12 @@ class LayerCollection(object):
ValueError: If the parameters were already registered in a layer or
identified as part of an incompatible group.
"""
- params = frozenset(ensure_sequence(params))
+ params = frozenset(utils.ensure_sequence(params))
# Check if any of the variables in 'params' is already in
# 'self.fisher_blocks.keys()'.
for registered_params, fisher_block in self.fisher_blocks.items():
- registered_params_set = set(ensure_sequence(registered_params))
+ registered_params_set = set(utils.ensure_sequence(registered_params))
for variable in params:
if (variable in registered_params_set and
params != registered_params_set):
@@ -421,7 +411,7 @@ class LayerCollection(object):
def _get_linked_approx(self, params):
"""If params were linked, return their specified approximation."""
- params_set = frozenset(ensure_sequence(params))
+ params_set = frozenset(utils.ensure_sequence(params))
if params_set in self.linked_parameters:
return self.linked_parameters[params_set]
else:
@@ -727,7 +717,6 @@ class LayerCollection(object):
key = cls, args
if key not in self.fisher_factors:
- colo = self._colocate_cov_ops_with_inputs
with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo)
+ self.fisher_factors[key] = cls(*args)
return self.fisher_factors[key]
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 48b191ef50..d717f427e6 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -344,5 +344,13 @@ def cross_replica_mean(tensor, name=None):
return tpu_ops.cross_replica_sum(tensor / num_shards)
+def ensure_sequence(obj):
+ """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
+ if isinstance(obj, (tuple, list)):
+ return obj
+ else:
+ return (obj,)
+
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 8903c90fbc..074dc579da 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -37,6 +37,7 @@ _allowed_symbols = [
"SubGraph",
"generate_random_signs",
"fwd_gradients",
+ "ensure_sequence",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)