From 5c080afd1eda3d631f24b96750d3ec9c794144ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 17 Jan 2018 13:58:05 -0800 Subject: K-FAC: Fixes problem where initial eigendecomposition used by RNN classes would be computed during the call to instantiate_factors (via the "registration" functions) instead of the call to make_inverse_update_ops. This messed up the device placement of these ops and interacted badly with other parts of the code. Also changed how covariance op creation is done in ConvDiagonalFactor in anticipation of similar problems in the future. As of this CL, none of the op creation methods will modify the state of the class, and no ops will be created outside of the op creation methods. We should try to follow this convention going forward. PiperOrigin-RevId: 182265266 --- .../contrib/kfac/python/ops/fisher_factors.py | 82 ++++++++++++---------- 1 file changed, 44 insertions(+), 38 deletions(-) (limited to 'tensorflow/contrib/kfac') diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index a069f6bdd9..f59168cbc0 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -365,28 +365,16 @@ class InverseProvidingFactor(FisherFactor): dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower - def register_eigendecomp(self): - """Registers an eigendecomposition. - - Unlike register_damp_inverse and register_matpower this doesn't create - any variables or inverse ops. Instead it merely makes tensors containing - the eigendecomposition available to anyone that wants them. They will be - recomputed (once) for each session.run() call (when they needed by some op). - """ - if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) - - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least FLAGS.eigenvalue_clipping_threshold - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - self._eigendecomp = (clipped_eigenvalues, eigenvectors) - def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] + # We do this to ensure that we don't reuse the eigendecomp from old calls + # to make_inverse_update_ops that may be placed on different devices. This + # can happen is the user has both a permanent and lazily constructed + # version of the inverse ops (and only uses one of them). + self.reset_eigendecomp() + num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) use_eig = ( @@ -394,8 +382,7 @@ class InverseProvidingFactor(FisherFactor): num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: - self.register_eigendecomp() # ensures self._eigendecomp is set - eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for damping, inv in self._inverses_by_damping.items(): ops.append( @@ -430,11 +417,25 @@ class InverseProvidingFactor(FisherFactor): return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_inverse and get_matpower this doesn't retrieve a stored # variable, but instead always computes a fresh version from the current # value of get_cov(). + if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + return self._eigendecomp + def reset_eigendecomp(self): + self._eigendecomp = None + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. @@ -661,25 +662,32 @@ class ConvDiagonalFactor(DiagonalFactor): def _dtype(self): return self._outputs_grads[0].dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - if self._patches is None: - filter_height, filter_width, _, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + def make_covariance_update_op(self, ema_decay): + with maybe_colocate_with(self._inputs): + filter_height, filter_width, _, _ = self._filter_shape - if self._has_bias: - patches = append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches = patches - self._patches = patches + op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + self._patches = None + + return op + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._outputs_grads[idx]): outputs_grad = self._outputs_grads[idx] batch_size = array_ops.shape(self._patches)[0] @@ -1009,7 +1017,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor): def register_option1quants(self, damping): - self.register_eigendecomp() self.register_cov_dt1() if damping not in self._option1quants_by_damping: @@ -1035,7 +1042,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor): def register_option2quants(self, damping): - self.register_eigendecomp() self.register_cov_dt1() if damping not in self._option2quants_by_damping: -- cgit v1.2.3