aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py82
1 files changed, 44 insertions, 38 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index a069f6bdd9..f59168cbc0 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -365,28 +365,16 @@ class InverseProvidingFactor(FisherFactor):
dtype=self._dtype)
self._matpower_by_exp_and_damping[(exp, damping)] = matpower
- def register_eigendecomp(self):
- """Registers an eigendecomposition.
-
- Unlike register_damp_inverse and register_matpower this doesn't create
- any variables or inverse ops. Instead it merely makes tensors containing
- the eigendecomposition available to anyone that wants them. They will be
- recomputed (once) for each session.run() call (when they needed by some op).
- """
- if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
-
- # The matrix self._cov is positive semidefinite by construction, but the
- # numerical eigenvalues could be negative due to numerical errors, so here
- # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
- clipped_eigenvalues = math_ops.maximum(eigenvalues,
- EIGENVALUE_CLIPPING_THRESHOLD)
- self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
ops = []
+ # We do this to ensure that we don't reuse the eigendecomp from old calls
+ # to make_inverse_update_ops that may be placed on different devices. This
+ # can happen if the user has both a permanent and lazily constructed
+ # version of the inverse ops (and only uses one of them).
+ self.reset_eigendecomp()
+
num_inverses = len(self._inverses_by_damping)
matrix_power_registered = bool(self._matpower_by_exp_and_damping)
use_eig = (
@@ -394,8 +382,7 @@ class InverseProvidingFactor(FisherFactor):
num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
if use_eig:
- self.register_eigendecomp() # ensures self._eigendecomp is set
- eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence
+ eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
for damping, inv in self._inverses_by_damping.items():
ops.append(
@@ -430,11 +417,25 @@ class InverseProvidingFactor(FisherFactor):
return self._matpower_by_exp_and_damping[(exp, damping)]
def get_eigendecomp(self):
+ """Creates or retrieves eigendecomposition of self._cov."""
# Unlike get_inverse and get_matpower this doesn't retrieve a stored
# variable, but instead always computes a fresh version from the current
# value of get_cov().
+ if not self._eigendecomp:
+ eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
+
+ # The matrix self._cov is positive semidefinite by construction, but the
+ # numerical eigenvalues could be negative due to numerical errors, so here
+ # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
+ clipped_eigenvalues = math_ops.maximum(eigenvalues,
+ EIGENVALUE_CLIPPING_THRESHOLD)
+ self._eigendecomp = (clipped_eigenvalues, eigenvectors)
+
return self._eigendecomp
+ def reset_eigendecomp(self):
+ self._eigendecomp = None
+
class FullFactor(InverseProvidingFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
@@ -661,25 +662,32 @@ class ConvDiagonalFactor(DiagonalFactor):
def _dtype(self):
return self._outputs_grads[0].dtype
- def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._outputs_grads[idx]):
- if self._patches is None:
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
+ def make_covariance_update_op(self, ema_decay):
+ with maybe_colocate_with(self._inputs):
+ filter_height, filter_width, _, _ = self._filter_shape
- if self._has_bias:
- patches = append_homog(patches)
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = append_homog(patches)
+
+ self._patches = patches
- self._patches = patches
+ op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
+ self._patches = None
+
+ return op
+
+ def _compute_new_cov(self, idx=0):
+ with maybe_colocate_with(self._outputs_grads[idx]):
outputs_grad = self._outputs_grads[idx]
batch_size = array_ops.shape(self._patches)[0]
@@ -1009,7 +1017,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def register_option1quants(self, damping):
- self.register_eigendecomp()
self.register_cov_dt1()
if damping not in self._option1quants_by_damping:
@@ -1035,7 +1042,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def register_option2quants(self, damping):
- self.register_eigendecomp()
self.register_cov_dt1()
if damping not in self._option2quants_by_damping: