aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-17 13:58:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-17 14:01:58 -0800
commit5c080afd1eda3d631f24b96750d3ec9c794144ee (patch)
tree4c6a3f928e39c661fbfd9d139d0d2ce9cb1de081 /tensorflow/contrib/kfac
parent9eb734b94773fe5422b39b66f1a704b7934167d4 (diff)
K-FAC: Fixes problem where initial eigendecomposition used by RNN classes would be computed during the call to instantiate_factors (via the "registration" functions) instead of the call to make_inverse_update_ops. This messed up the device placement of these ops and interacted badly with other parts of the code.
Also changed how covariance op creation is done in ConvDiagonalFactor in anticipation of similar problems in the future. As of this CL, none of the op creation methods will modify the state of the class, and no ops will be created outside of the op creation methods. We should try to follow this convention going forward. PiperOrigin-RevId: 182265266
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py82
1 files changed, 44 insertions, 38 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index a069f6bdd9..f59168cbc0 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -365,28 +365,16 @@ class InverseProvidingFactor(FisherFactor):
dtype=self._dtype)
self._matpower_by_exp_and_damping[(exp, damping)] = matpower
- def register_eigendecomp(self):
- """Registers an eigendecomposition.
-
- Unlike register_damp_inverse and register_matpower this doesn't create
- any variables or inverse ops. Instead it merely makes tensors containing
- the eigendecomposition available to anyone that wants them. They will be
- recomputed (once) for each session.run() call (when they needed by some op).
- """
- if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
-
- # The matrix self._cov is positive semidefinite by construction, but the
- # numerical eigenvalues could be negative due to numerical errors, so here
- # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
- clipped_eigenvalues = math_ops.maximum(eigenvalues,
- EIGENVALUE_CLIPPING_THRESHOLD)
- self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
ops = []
+ # We do this to ensure that we don't reuse the eigendecomp from old calls
+ # to make_inverse_update_ops that may be placed on different devices. This
+ # can happen is the user has both a permanent and lazily constructed
+ # version of the inverse ops (and only uses one of them).
+ self.reset_eigendecomp()
+
num_inverses = len(self._inverses_by_damping)
matrix_power_registered = bool(self._matpower_by_exp_and_damping)
use_eig = (
@@ -394,8 +382,7 @@ class InverseProvidingFactor(FisherFactor):
num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
if use_eig:
- self.register_eigendecomp() # ensures self._eigendecomp is set
- eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence
+ eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
for damping, inv in self._inverses_by_damping.items():
ops.append(
@@ -430,11 +417,25 @@ class InverseProvidingFactor(FisherFactor):
return self._matpower_by_exp_and_damping[(exp, damping)]
def get_eigendecomp(self):
+ """Creates or retrieves eigendecomposition of self._cov."""
# Unlike get_inverse and get_matpower this doesn't retrieve a stored
# variable, but instead always computes a fresh version from the current
# value of get_cov().
+ if not self._eigendecomp:
+ eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
+
+ # The matrix self._cov is positive semidefinite by construction, but the
+ # numerical eigenvalues could be negative due to numerical errors, so here
+ # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
+ clipped_eigenvalues = math_ops.maximum(eigenvalues,
+ EIGENVALUE_CLIPPING_THRESHOLD)
+ self._eigendecomp = (clipped_eigenvalues, eigenvectors)
+
return self._eigendecomp
+ def reset_eigendecomp(self):
+ self._eigendecomp = None
+
class FullFactor(InverseProvidingFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
@@ -661,25 +662,32 @@ class ConvDiagonalFactor(DiagonalFactor):
def _dtype(self):
return self._outputs_grads[0].dtype
- def _compute_new_cov(self, idx=0):
- with maybe_colocate_with(self._outputs_grads[idx]):
- if self._patches is None:
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- patches = array_ops.extract_image_patches(
- self._inputs,
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=[1, 1, 1, 1],
- padding=self._padding)
+ def make_covariance_update_op(self, ema_decay):
+ with maybe_colocate_with(self._inputs):
+ filter_height, filter_width, _, _ = self._filter_shape
- if self._has_bias:
- patches = append_homog(patches)
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ patches = array_ops.extract_image_patches(
+ self._inputs,
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=[1, 1, 1, 1],
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = append_homog(patches)
+
+ self._patches = patches
- self._patches = patches
+ op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
+ self._patches = None
+
+ return op
+
+ def _compute_new_cov(self, idx=0):
+ with maybe_colocate_with(self._outputs_grads[idx]):
outputs_grad = self._outputs_grads[idx]
batch_size = array_ops.shape(self._patches)[0]
@@ -1009,7 +1017,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def register_option1quants(self, damping):
- self.register_eigendecomp()
self.register_cov_dt1()
if damping not in self._option1quants_by_damping:
@@ -1035,7 +1042,6 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
def register_option2quants(self, damping):
- self.register_eigendecomp()
self.register_cov_dt1()
if damping not in self._option2quants_by_damping: