diff options
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/fisher_factors.py | 82 |
1 files changed, 44 insertions, 38 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index a069f6bdd9..f59168cbc0 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -365,28 +365,16 @@ class InverseProvidingFactor(FisherFactor): dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower - def register_eigendecomp(self): - """Registers an eigendecomposition. - - Unlike register_damp_inverse and register_matpower this doesn't create - any variables or inverse ops. Instead it merely makes tensors containing - the eigendecomposition available to anyone that wants them. They will be - recomputed (once) for each session.run() call (when they needed by some op). - """ - if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) - - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least FLAGS.eigenvalue_clipping_threshold - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - self._eigendecomp = (clipped_eigenvalues, eigenvectors) - def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] + # We do this to ensure that we don't reuse the eigendecomp from old calls + # to make_inverse_update_ops that may be placed on different devices. This + # can happen if the user has both a permanent and lazily constructed + # version of the inverse ops (and only uses one of them). 
+ self.reset_eigendecomp() + num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) use_eig = ( @@ -394,8 +382,7 @@ class InverseProvidingFactor(FisherFactor): num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: - self.register_eigendecomp() # ensures self._eigendecomp is set - eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for damping, inv in self._inverses_by_damping.items(): ops.append( @@ -430,11 +417,25 @@ class InverseProvidingFactor(FisherFactor): return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_inverse and get_matpower this doesn't retrieve a stored # variable, but instead always computes a fresh version from the current # value of get_cov(). + if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + return self._eigendecomp + def reset_eigendecomp(self): + self._eigendecomp = None + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. 
@@ -661,25 +662,32 @@ class ConvDiagonalFactor(DiagonalFactor): def _dtype(self): return self._outputs_grads[0].dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - if self._patches is None: - filter_height, filter_width, _, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + def make_covariance_update_op(self, ema_decay): + with maybe_colocate_with(self._inputs): + filter_height, filter_width, _, _ = self._filter_shape - if self._has_bias: - patches = append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches = patches - self._patches = patches + op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + self._patches = None + + return op + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._outputs_grads[idx]): outputs_grad = self._outputs_grads[idx] batch_size = array_ops.shape(self._patches)[0] @@ -1009,7 +1017,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor): def register_option1quants(self, damping): - self.register_eigendecomp() self.register_cov_dt1() if damping not in self._option1quants_by_damping: @@ -1035,7 +1042,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor): def register_option2quants(self, damping): - self.register_eigendecomp() self.register_cov_dt1() if damping not in self._option2quants_by_damping: |