5 files changed, 167 insertions, 184 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 70e56db055..a2665b9279 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -35,18 +35,27 @@ from tensorflow.python.platform import test

 class MaybeColocateTest(test.TestCase):

+  def setUp(self):
+    self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS
+
+  def tearDown(self):
+    ff.set_global_constants(
+        colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs)
+
   def testFalse(self):
+    ff.set_global_constants(colocate_cov_ops_with_inputs=False)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a, False):
+      with ff._maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@b'], b.op.colocation_groups())

   def testTrue(self):
+    ff.set_global_constants(colocate_cov_ops_with_inputs=True)
     with tf_ops.Graph().as_default():
       a = constant_op.constant([2.0], name='a')
-      with ff._maybe_colocate_with(a, True):
+      with ff._maybe_colocate_with(a):
         b = constant_op.constant(3.0, name='b')
       self.assertEqual([b'loc:@a'], a.op.colocation_groups())
       self.assertEqual([b'loc:@a'], b.op.colocation_groups())
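The tests above no longer pass a colocation flag into `_maybe_colocate_with`; they flip the module-level constant and restore it afterwards so state cannot leak between tests. A minimal, self-contained sketch of that save-and-restore pattern (the `FLAG` global here is a stand-in for `ff.COLOCATE_COV_OPS_WITH_INPUTS`, not part of this commit):

import unittest

FLAG = True  # stand-in for a module-level constant like COLOCATE_COV_OPS_WITH_INPUTS


class FlagTest(unittest.TestCase):

  def setUp(self):
    # Remember the global's value before the test mutates it.
    self._saved_flag = FLAG

  def tearDown(self):
    # Restore it so later tests see the default again.
    global FLAG
    FLAG = self._saved_flag

  def testFalse(self):
    global FLAG
    FLAG = False
    self.assertFalse(FLAG)


if __name__ == '__main__':
  unittest.main()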
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 2c00fc14d9..826e8b7732 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -52,11 +52,15 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
 # matrix powers. Must be nonnegative.
 EIGENVALUE_CLIPPING_THRESHOLD = 0.0

+# Colocate the covariance ops and variables with the input tensors for each
+# factor.
+COLOCATE_COV_OPS_WITH_INPUTS = True
+

 @contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
-  """Context to colocate with `op` if `colocate_cov_ops_with_inputs`."""
-  if colocate_cov_ops_with_inputs:
+def _maybe_colocate_with(op):
+  """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`."""
+  if COLOCATE_COV_OPS_WITH_INPUTS:
     if isinstance(op, (list, tuple)):
       with tf_ops.colocate_with(op[0]):
         yield
@@ -70,12 +74,14 @@ def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
 def set_global_constants(init_covariances_at_zero=None,
                          zero_debias=None,
                          eigenvalue_decomposition_threshold=None,
-                         eigenvalue_clipping_threshold=None):
+                         eigenvalue_clipping_threshold=None,
+                         colocate_cov_ops_with_inputs=None):
   """Sets various global constants used by the classes in this module."""
   global INIT_COVARIANCES_AT_ZERO
   global ZERO_DEBIAS
   global EIGENVALUE_DECOMPOSITION_THRESHOLD
   global EIGENVALUE_CLIPPING_THRESHOLD
+  global COLOCATE_COV_OPS_WITH_INPUTS

   if init_covariances_at_zero is not None:
     INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
@@ -85,6 +91,8 @@ def set_global_constants(init_covariances_at_zero=None,
     EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
   if eigenvalue_clipping_threshold is not None:
     EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+  if colocate_cov_ops_with_inputs is not None:
+    COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs


 def inverse_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
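`_maybe_colocate_with` now reads the module global at call time instead of taking a flag argument. A rough, self-contained sketch of this pattern in plain Python (`COLOCATE` and `_maybe_do` are illustrative names; the real context manager wraps `tf_ops.colocate_with`):

import contextlib

COLOCATE = True  # stand-in for COLOCATE_COV_OPS_WITH_INPUTS


@contextlib.contextmanager
def _maybe_do(op):
  # The global is consulted when the context is entered, so callers that
  # run after set_global_constants() see the updated setting.
  if COLOCATE:
    print('would colocate with', op)  # real code: tf_ops.colocate_with(op)
    yield
  else:
    yield


def set_global_constants(colocate=None):
  global COLOCATE
  if colocate is not None:
    COLOCATE = colocate


set_global_constants(colocate=False)
with _maybe_do('a'):
  pass  # no colocation applied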
@@ -264,15 +272,22 @@ class FisherFactor(object):

     Returns:
       An Op for updating the covariance Variable referenced by _cov.
     """
-    new_cov = math_ops.add_n(
-        tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
-
-    # Synchronize value across all TPU cores.
-    if utils.on_tpu():
-      new_cov = utils.cross_replica_mean(new_cov)
-
-    return moving_averages.assign_moving_average(
-        self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+    new_cov_contribs = tuple(self._compute_new_cov(idx)
+                             for idx in range(self._num_sources))
+    # This gets the job done but we might want a better solution in the future.
+    # In particular, we could have a separate way of specifying where the
+    # cov variables finally end up, independent of where their various
+    # contributions are computed. Right now these are the same thing, but in
+    # the future we might want to perform the cov computations on each tower,
+    # so that each tower will be considered a "source" (allowing us to reuse
+    # the existing "source" code for this).
+    with _maybe_colocate_with(new_cov_contribs[0]):
+      new_cov = math_ops.add_n(new_cov_contribs)
+      # Synchronize value across all TPU cores.
+      if utils.on_tpu():
+        new_cov = utils.cross_replica_mean(new_cov)
+      return moving_averages.assign_moving_average(
+          self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)

   @abc.abstractmethod
   def make_inverse_update_ops(self):
@@ -430,45 +445,38 @@ class FullFactor(InverseProvidingFactor):

   def __init__(self,
                params_grads,
-               batch_size,
-               colocate_cov_ops_with_inputs=False):
+               batch_size):
     self._batch_size = batch_size
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
-    self._orig_params_grads_name = scope_string_from_params(
-        [params_grads, self._batch_size])
-    params_grads_flat = []
-    for params_grad in params_grads:
-      with _maybe_colocate_with(params_grad,
-                                self._colocate_cov_ops_with_inputs):
-        col = utils.tensors_to_column(params_grad)
-        params_grads_flat.append(col)
-    self._params_grads_flat = tuple(params_grads_flat)
+    self._params_grads = tuple(utils.ensure_sequence(params_grad)
+                               for params_grad in params_grads)
     super(FullFactor, self).__init__()

   @property
   def _var_scope(self):
-    return "ff_full/" + self._orig_params_grads_name
+    return "ff_full/" + scope_string_from_params(
+        [self._params_grads, self._batch_size])

   @property
   def _cov_shape(self):
-    size = self._params_grads_flat[0].shape[0]
-    return [size, size]
+    size = sum(param_grad.shape.num_elements()
+               for param_grad in self._params_grads[0])
+    return (size, size)

   @property
   def _num_sources(self):
-    return len(self._params_grads_flat)
+    return len(self._params_grads)

   @property
   def _dtype(self):
-    return self._params_grads_flat[0].dtype
+    return self._params_grads[0][0].dtype

   def _compute_new_cov(self, idx=0):
     # This will be a very basic rank 1 estimate
-    with _maybe_colocate_with(self._params_grads_flat[idx],
-                              self._colocate_cov_ops_with_inputs):
-      return ((self._params_grads_flat[idx] * array_ops.transpose(
-          self._params_grads_flat[idx])) / math_ops.cast(
-              self._batch_size, self._params_grads_flat[idx].dtype))
+    with _maybe_colocate_with(self._params_grads[idx]):
+      params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+      return ((params_grads_flat * array_ops.transpose(
+          params_grads_flat)) / math_ops.cast(self._batch_size,
+                                              params_grads_flat.dtype))


 class DiagonalFactor(FisherFactor):
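FullFactor's per-source contribution is the rank-1 estimate g gᵀ / batch_size of the flattened gradient g; multiplying a column element-wise against its transpose broadcasts into the full outer product. The same arithmetic as a numpy illustration (not taken from this commit):

import numpy as np

batch_size = 32
g = np.random.randn(10, 1)  # flattened gradient column, as tensors_to_column yields

# Broadcasting a (10, 1) column against its (1, 10) transpose gives the
# (10, 10) outer product: the basic rank-1 covariance estimate.
new_cov = (g * g.T) / batch_size
assert new_cov.shape == (10, 10)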
@@ -494,28 +502,22 @@ class NaiveDiagonalFactor(DiagonalFactor):

   def __init__(self,
                params_grads,
-               batch_size,
-               colocate_cov_ops_with_inputs=False):
+               batch_size):
+    self._params_grads = tuple(utils.ensure_sequence(params_grad)
+                               for params_grad in params_grads)
     self._batch_size = batch_size
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
-    params_grads_flat = []
-    for params_grad in params_grads:
-      with _maybe_colocate_with(params_grad,
-                                self._colocate_cov_ops_with_inputs):
-        col = utils.tensors_to_column(params_grad)
-        params_grads_flat.append(col)
-    self._params_grads = tuple(params_grads_flat)
-    self._orig_params_grads_name = scope_string_from_params(
-        [self._params_grads, self._batch_size])
     super(NaiveDiagonalFactor, self).__init__()

   @property
   def _var_scope(self):
-    return "ff_naivediag/" + self._orig_params_grads_name
+    return "ff_naivediag/" + scope_string_from_params(
+        [self._params_grads, self._batch_size])

   @property
   def _cov_shape(self):
-    return self._params_grads[0].shape
+    size = sum(param_grad.shape.num_elements()
+               for param_grad in self._params_grads[0])
+    return (size, 1)

   @property
   def _num_sources(self):
@@ -523,13 +525,13 @@ class NaiveDiagonalFactor(DiagonalFactor):

   @property
   def _dtype(self):
-    return self._params_grads[0].dtype
+    return self._params_grads[0][0].dtype

   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._params_grads[idx],
-                              self._colocate_cov_ops_with_inputs):
-      return (math_ops.square(self._params_grads[idx]) / math_ops.cast(
-          self._batch_size, self._params_grads[idx].dtype))
+    with _maybe_colocate_with(self._params_grads[idx]):
+      params_grads_flat = utils.tensors_to_column(self._params_grads[idx])
+      return (math_ops.square(params_grads_flat) / math_ops.cast(
+          self._batch_size, params_grads_flat.dtype))


 class FullyConnectedDiagonalFactor(DiagonalFactor):
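NaiveDiagonalFactor keeps only squared entries of the flattened gradient, so its covariance is a column of shape (total parameter count, 1), matching the new `_cov_shape`. A numpy illustration (shapes chosen arbitrarily):

import numpy as np

batch_size = 32
# A parameter group's gradients, flattened into one column the way
# utils.tensors_to_column does.
grads = [np.random.randn(3, 4), np.random.randn(4)]
col = np.concatenate([g.reshape(-1, 1) for g in grads], axis=0)

# Diagonal covariance estimate: element-wise squares, scaled by batch size.
new_cov = np.square(col) / batch_size
assert new_cov.shape == (16, 1)  # 3*4 + 4 parameters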
@@ -543,13 +545,10 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
   where the square is taken element-wise.
   """

-  # TODO(jamesmartens): add units tests for this class
-
   def __init__(self,
                inputs,
                outputs_grads,
-               has_bias=False,
-               colocate_cov_ops_with_inputs=False):
+               has_bias=False):
     """Instantiate FullyConnectedDiagonalFactor.

     Args:
@@ -558,32 +557,24 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
       outputs_grads: List of Tensors of shape [batch_size, output_size].
         Gradient of loss with respect to layer's preactivations.
       has_bias: bool. If True, append '1' to each input.
-      colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
-        their inputs.
     """
+    self._inputs = inputs
+    self._has_bias = has_bias
     self._outputs_grads = outputs_grads
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
     self._batch_size = array_ops.shape(inputs)[0]
-    self._orig_tensors_name = scope_string_from_params(
-        (inputs,) + tuple(outputs_grads))
-
-    # Note that we precompute the required operations on the inputs since the
-    # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
-    # the target entry of _outputs_grads changes with idx.)
-    with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
-      if has_bias:
-        inputs = _append_homog(inputs)
-      self._squared_inputs = math_ops.square(inputs)
+    self._squared_inputs = None

     super(FullyConnectedDiagonalFactor, self).__init__()

   @property
   def _var_scope(self):
-    return "ff_diagfc/" + self._orig_tensors_name
+    return "ff_diagfc/" + scope_string_from_params(
+        (self._inputs,) + tuple(self._outputs_grads))

   @property
   def _cov_shape(self):
-    return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]]
+    return [self._inputs.shape[1] + self._has_bias,
+            self._outputs_grads[0].shape[1]]

   @property
   def _num_sources(self):
@@ -598,8 +589,14 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
     # square of an outer product is the outer-product of the entry-wise squares.
     # The gradient is the outer product of the input and the output gradients,
     # so we just square both and then take their outer-product.
-    with _maybe_colocate_with(self._squared_inputs,
-                              self._colocate_cov_ops_with_inputs):
+    with _maybe_colocate_with(self._outputs_grads[idx]):
+      # We only need to compute squared_inputs once
+      if self._squared_inputs is None:
+        inputs = self._inputs
+        if self._has_bias:
+          inputs = _append_homog(self._inputs)
+        self._squared_inputs = math_ops.square(inputs)
+
       new_cov = math_ops.matmul(
           self._squared_inputs,
           math_ops.square(self._outputs_grads[idx]),
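The squared inputs are now computed lazily on the first `_compute_new_cov` call rather than eagerly in the constructor, which lets the op land inside the colocation context chosen at update time. The general memoize-on-first-use pattern as a self-contained sketch (class name illustrative):

class LazilySquared(object):
  """Defers an input-derived computation until it is first needed."""

  def __init__(self, inputs):
    self._inputs = inputs
    self._squared = None  # computed on first use, then cached

  def compute(self):
    if self._squared is None:
      # Only the first call pays for (and places) this computation.
      self._squared = [x * x for x in self._inputs]
    return self._squared


lazy = LazilySquared([1.0, 2.0, 3.0])
assert lazy.compute() == [1.0, 4.0, 9.0]
assert lazy.compute() is lazy.compute()  # cached, not recomputed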
@@ -611,16 +608,13 @@ class ConvDiagonalFactor(DiagonalFactor):
 class ConvDiagonalFactor(DiagonalFactor):
   """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""

-  # TODO(jamesmartens): add units tests for this class
-
   def __init__(self,
                inputs,
                outputs_grads,
                filter_shape,
                strides,
                padding,
-               has_bias=False,
-               colocate_cov_ops_with_inputs=False):
+               has_bias=False):
     """Creates a ConvDiagonalFactor object.

     Args:
@@ -635,42 +629,21 @@ class ConvDiagonalFactor(DiagonalFactor):
       padding: The padding in this layer (1-D of Tensor length 4).
       has_bias: Python bool. If True, the layer is assumed to have a bias
         parameter in addition to its filter parameter.
-      colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
-        their inputs.
     """
+    self._inputs = inputs
     self._filter_shape = filter_shape
+    self._strides = strides
+    self._padding = padding
     self._has_bias = has_bias
     self._outputs_grads = outputs_grads
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
-
-    self._orig_tensors_name = scope_string_from_name(
-        (inputs,) + tuple(outputs_grads))
-
-    # Note that we precompute the required operations on the inputs since the
-    # inputs don't change with the 'idx' argument to _compute_new_cov. (Only
-    # the target entry of _outputs_grads changes with idx.)
-    with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
-      filter_height, filter_width, _, _ = self._filter_shape
-
-      # TODO(b/64144716): there is potential here for a big savings in terms of
-      # memory use.
-      patches = array_ops.extract_image_patches(
-          inputs,
-          ksizes=[1, filter_height, filter_width, 1],
-          strides=strides,
-          rates=[1, 1, 1, 1],
-          padding=padding)
-
-      if has_bias:
-        patches = _append_homog(patches)
-
-      self._patches = patches
+    self._patches = None

     super(ConvDiagonalFactor, self).__init__()

   @property
   def _var_scope(self):
-    return "ff_convdiag/" + self._orig_tensors_name
+    return "ff_convdiag/" + scope_string_from_name(
+        (self._inputs,) + tuple(self._outputs_grads))

   @property
   def _cov_shape(self):
@@ -689,8 +662,24 @@ class ConvDiagonalFactor(DiagonalFactor):
     return self._outputs_grads[0].dtype

   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._outputs_grads[idx],
-                              self._colocate_cov_ops_with_inputs):
+    with _maybe_colocate_with(self._outputs_grads[idx]):
+      if self._patches is None:
+        filter_height, filter_width, _, _ = self._filter_shape
+
+        # TODO(b/64144716): there is potential here for a big savings in terms
+        # of memory use.
+        patches = array_ops.extract_image_patches(
+            self._inputs,
+            ksizes=[1, filter_height, filter_width, 1],
+            strides=self._strides,
+            rates=[1, 1, 1, 1],
+            padding=self._padding)
+
+        if self._has_bias:
+          patches = _append_homog(patches)
+
+        self._patches = patches
+
       outputs_grad = self._outputs_grads[idx]
       batch_size = array_ops.shape(self._patches)[0]
@@ -714,22 +703,18 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):

   def __init__(self,
                tensors,
-               has_bias=False,
-               colocate_cov_ops_with_inputs=False):
+               has_bias=False):
     """Instantiate FullyConnectedKroneckerFactor.

     Args:
       tensors: List of Tensors of shape [batch_size, n]. Represents either a
         layer's inputs or its output's gradients.
       has_bias: bool. If True, append '1' to each row.
-      colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
-        their inputs.
     """
     # The tensor argument is either a tensor of input activations or a tensor of
     # output pre-activation gradients.
     self._has_bias = has_bias
     self._tensors = tensors
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
     super(FullyConnectedKroneckerFactor, self).__init__()

   @property
@@ -751,8 +736,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
     return self._tensors[0].dtype

   def _compute_new_cov(self, idx=0):
-    with _maybe_colocate_with(self._tensors[idx],
-                              self._colocate_cov_ops_with_inputs):
+    with _maybe_colocate_with(self._tensors[idx]):
       tensor = self._tensors[idx]
       if self._has_bias:
         tensor = _append_homog(tensor)
@@ -774,8 +758,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
                filter_shape,
                strides,
                padding,
-               has_bias=False,
-               colocate_cov_ops_with_inputs=False):
+               has_bias=False):
     """Initializes ConvInputKroneckerFactor.

     Args:
@@ -787,15 +770,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
         width_stride, in_channel_stride].
       padding: str. Padding method for layer. "SAME" or "VALID".
       has_bias: bool. If True, append 1 to in_channel.
-      colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
-        their inputs.
     """
     self._filter_shape = filter_shape
     self._strides = strides
     self._padding = padding
     self._has_bias = has_bias
     self._inputs = inputs
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
     super(ConvInputKroneckerFactor, self).__init__()

   @property
@@ -823,8 +803,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
     if idx != 0:
       raise ValueError("ConvInputKroneckerFactor only supports idx = 0")

-    # TODO(jamesmartens): factor this patches stuff out into a utility function
-    with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs):
+    with _maybe_colocate_with(self._inputs):
       filter_height, filter_width, in_channels, _ = self._filter_shape

       # TODO(b/64144716): there is potential here for a big savings in terms of
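Both Kronecker factors ultimately estimate a second-moment matrix of row vectors (activations or preactivation gradients), with a homogeneous '1' column appended when the layer has a bias. Roughly what `_compute_cov` does, sketched in numpy (not this module's code; shapes arbitrary):

import numpy as np

batch_size, n = 32, 10
tensor = np.random.randn(batch_size, n)  # rows: per-example activations or grads

# has_bias: append the homogeneous coordinate, as _append_homog does.
tensor = np.concatenate([tensor, np.ones((batch_size, 1))], axis=1)

# Second-moment estimate over the batch.
cov = tensor.T @ tensor / batch_size
assert cov.shape == (n + 1, n + 1)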
""" self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -900,8 +876,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): return self._outputs_grads[0].dtype def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._outputs_grads[idx], - self._colocate_cov_ops_with_inputs): + with _maybe_colocate_with(self._outputs_grads[idx]): # reshaped_tensor below is the matrix DS_l defined in the KFC paper # (tilde omitted over S for clarity). It has shape M|T| x I, where # M = minibatch size, |T| = number of spatial locations, and @@ -920,57 +895,49 @@ class FullyConnectedMultiKF(InverseProvidingFactor): def __init__(self, tensor_lists, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: tensor_lists: List of lists of Tensors of shape [batch_size, n]. has_bias: bool. If True, '1' is appended to each row. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ - self._orig_tensors_name = scope_string_from_params(tensor_lists) + self._tensor_lists = tensor_lists + self._has_bias = has_bias self._batch_size = array_ops.shape(tensor_lists[0][0])[0] self._num_timesteps = len(tensor_lists[0]) - - tensors = tuple( - array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists) - if has_bias: - tensors = tuple(_append_homog(tensor) for tensor in tensors) - self._tensors = tensors + self._tensors = [None] * len(tensor_lists) self._cov_dt1 = None self._option1quants_by_damping = {} self._option2quants_by_damping = {} - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedMultiKF, self).__init__() @property def _var_scope(self): - return "ff_fc_multi/" + self._orig_tensors_name + return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) @property def _num_sources(self): - return len(self._tensors) + return len(self._tensor_lists) @property def _dtype(self): - return self._tensors[0].dtype + return self._tensor_lists[0][0].dtype def make_covariance_update_op(self, ema_decay): - with _maybe_colocate_with(self._tensors, - self._colocate_cov_ops_with_inputs): - op = super(FullyConnectedMultiKF, - self).make_covariance_update_op(ema_decay) - - if self._cov_dt1 is not None: - new_cov_dt1 = math_ops.add_n( - tuple( - self._compute_new_cov_dt1(idx) - for idx in range(self._num_sources))) + + op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources)) + + with _maybe_colocate_with(new_cov_dt1_contribs[0]): + new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + op2 = moving_averages.assign_moving_average( self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) @@ -984,26 +951,35 @@ class FullyConnectedMultiKF(InverseProvidingFactor): return op def _compute_new_cov(self, idx=0): - tensor = self._tensors[idx] - normalizer = self._num_timesteps * self._batch_size - return _compute_cov(tensor, normalizer=normalizer) + with _maybe_colocate_with(self._tensor_lists[idx]): + tensor = array_ops.concat(self._tensor_lists[idx], 0) + if self._has_bias: + tensor = _append_homog(tensor) + # We save these so they can be used by _compute_new_cov_dt1 + self._tensors[idx] = tensor + return _compute_cov(tensor) def 
@@ -984,26 +951,35 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
     return op

   def _compute_new_cov(self, idx=0):
-    tensor = self._tensors[idx]
-    normalizer = self._num_timesteps * self._batch_size
-    return _compute_cov(tensor, normalizer=normalizer)
+    with _maybe_colocate_with(self._tensor_lists[idx]):
+      tensor = array_ops.concat(self._tensor_lists[idx], 0)
+      if self._has_bias:
+        tensor = _append_homog(tensor)
+      # We save these so they can be used by _compute_new_cov_dt1
+      self._tensors[idx] = tensor
+      return _compute_cov(tensor)

   def _compute_new_cov_dt1(self, idx=0):
     tensor = self._tensors[idx]
-    normalizer = self._num_timesteps * self._batch_size
-    tensor_present = tensor[:-self._batch_size, :]
-    tensor_future = tensor[self._batch_size:, :]
-    return _compute_cov(
-        tensor_future, tensor_right=tensor_present, normalizer=normalizer)
+    with _maybe_colocate_with(tensor):
+      # Is there a more elegant way to do this computation?
+      tensor_present = tensor[:-self._batch_size, :]
+      tensor_future = tensor[self._batch_size:, :]
+      # We specify a normalizer for this computation to ensure a PSD Fisher
+      # block estimate. This is equivalent to padding with zeros, as was done
+      # in Section B.2 of the appendix.
+      normalizer = self._num_timesteps * self._batch_size
+      return _compute_cov(
+          tensor_future, tensor_right=tensor_present, normalizer=normalizer)

   @property
   def _cov_shape(self):
-    size = self._tensors[0].shape[1]
+    size = self._tensor_lists[0][0].shape[1] + self._has_bias
     return [size, size]

   @property
   def _vec_shape(self):
-    size = self._tensors[0].shape[1]
+    size = self._tensor_lists[0][0].shape[1] + self._has_bias
     return [size]

   def get_option1quants(self, damping):
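The dt1 statistic pairs each row of the per-timestep concatenation with the row one timestep later; trimming batch_size rows off either end aligns the two views, and normalizing by the full T * batch_size (rather than (T-1) * batch_size) matches the zero-padding argument in the comment above. A numpy illustration (shapes arbitrary):

import numpy as np

batch_size, n, num_timesteps = 4, 3, 5
# Timesteps stacked along the batch axis, as array_ops.concat(tensor_list, 0)
# produces in _compute_new_cov.
tensor = np.random.randn(num_timesteps * batch_size, n)

tensor_present = tensor[:-batch_size, :]  # timesteps 0 .. T-2
tensor_future = tensor[batch_size:, :]    # timesteps 1 .. T-1

# Normalizer uses T, not T-1: equivalent to zero-padding the sequence,
# which keeps the overall Fisher block estimate PSD.
normalizer = num_timesteps * batch_size
cov_dt1 = tensor_future.T @ tensor_present / normalizer
assert cov_dt1.shape == (n, n)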
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ca42afe6fb..8d450f04f3 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -76,14 +76,6 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
 VARIABLE_SCOPE = "VARIABLE_SCOPE"


-def ensure_sequence(obj):
-  """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
-  if isinstance(obj, (tuple, list)):
-    return obj
-  else:
-    return (obj,)
-
-
 class LayerParametersDict(OrderedDict):
   """An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -142,7 +134,6 @@ class LayerCollection(object):

   def __init__(self,
                graph=None,
-               colocate_cov_ops_with_inputs=False,
                name="LayerCollection"):
     self.fisher_blocks = LayerParametersDict()
     self.fisher_factors = OrderedDict()
@@ -156,7 +147,6 @@ class LayerCollection(object):
     self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME
     self._default_fully_connected_multi_approximation = (
         APPROX_KRONECKER_SERIES_2_NAME)
-    self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs

     with variable_scope.variable_scope(None, default_name=name) as scope:
       self._var_scope = scope.name
@@ -169,7 +159,7 @@ class LayerCollection(object):
   @property
   def registered_variables(self):
     """A tuple of all of the variables currently registered."""
-    tuple_of_tuples = (ensure_sequence(key) for key, block
+    tuple_of_tuples = (utils.ensure_sequence(key) for key, block
                        in six.iteritems(self.fisher_blocks))
     flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
     return flat_tuple
@@ -276,9 +266,9 @@ class LayerCollection(object):
     variable_to_block = {
         var: (params, block)
         for (params, block) in self.fisher_blocks.items()
-        for var in ensure_sequence(params)
+        for var in utils.ensure_sequence(params)
     }
-    for variable in ensure_sequence(layer_key):
+    for variable in utils.ensure_sequence(layer_key):
       if variable in variable_to_block:
         prev_key, prev_block = variable_to_block[variable]
         raise ValueError(
@@ -301,7 +291,7 @@ class LayerCollection(object):
           block.num_inputs()*block.num_registered_minibatches if isinstance(
               block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB))
           else block.num_registered_minibatches)
-      key = ensure_sequence(key)
+      key = utils.ensure_sequence(key)
       for k in key:
         vars_to_uses[k] += n
     return vars_to_uses
@@ -382,12 +372,12 @@ class LayerCollection(object):
       ValueError: If the parameters were already registered in a layer or
         identified as part of an incompatible group.
     """
-    params = frozenset(ensure_sequence(params))
+    params = frozenset(utils.ensure_sequence(params))

     # Check if any of the variables in 'params' is already in
     # 'self.fisher_blocks.keys()'.
     for registered_params, fisher_block in self.fisher_blocks.items():
-      registered_params_set = set(ensure_sequence(registered_params))
+      registered_params_set = set(utils.ensure_sequence(registered_params))
       for variable in params:
         if (variable in registered_params_set and
             params != registered_params_set):
@@ -421,7 +411,7 @@ class LayerCollection(object):

   def _get_linked_approx(self, params):
     """If params were linked, return their specified approximation."""
-    params_set = frozenset(ensure_sequence(params))
+    params_set = frozenset(utils.ensure_sequence(params))
     if params_set in self.linked_parameters:
       return self.linked_parameters[params_set]
     else:
@@ -727,7 +717,6 @@ class LayerCollection(object):
     key = cls, args
     if key not in self.fisher_factors:
-      colo = self._colocate_cov_ops_with_inputs
       with variable_scope.variable_scope(self._var_scope):
-        self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo)
+        self.fisher_factors[key] = cls(*args)
     return self.fisher_factors[key]
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 48b191ef50..d717f427e6 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -344,5 +344,13 @@ def cross_replica_mean(tensor, name=None):
     return tpu_ops.cross_replica_sum(tensor / num_shards)


+def ensure_sequence(obj):
+  """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
+  if isinstance(obj, (tuple, list)):
+    return obj
+  else:
+    return (obj,)
+
+
 # TODO(b/69623235): Add a function for finding tensors that share gradients
 # to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 8903c90fbc..074dc579da 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -37,6 +37,7 @@ _allowed_symbols = [
     "SubGraph",
     "generate_random_signs",
     "fwd_gradients",
+    "ensure_sequence",
 ]

 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
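`ensure_sequence` moved from layer_collection.py into utils.py and is now exported via utils_lib. For reference, its behavior on the two cases it distinguishes (the function body is copied from the diff above; the assertions are illustrative):

def ensure_sequence(obj):
  """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
  if isinstance(obj, (tuple, list)):
    return obj
  else:
    return (obj,)


# A lone object is wrapped; an existing sequence is returned unchanged.
assert ensure_sequence('w') == ('w',)
assert ensure_sequence(['w', 'b']) == ['w', 'b']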