diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-11 15:39:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-11 15:43:28 -0800 |
commit | 9322010d25f0b11c8bc2a498672f270708697425 (patch) | |
tree | 2c0d1c59984bc5735593f2a428b7cb80832def8b /tensorflow/contrib/kfac | |
parent | 733a8ed8918a548867b110c452b67c95dda537e2 (diff) |
K-FAC: Add (cov|inv)_update_(ops|thunks) to FisherEstimator.
PiperOrigin-RevId: 181672525
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py | 114 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py | 87 |
3 files changed, 179 insertions, 27 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 4928bf2c10..17458ffa2a 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -17,12 +17,17 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 9b28c45c72..bfdb69ad02 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils @@ -25,11 +27,15 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test 
+from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] @@ -119,6 +125,114 @@ class EstimatorTest(test.TestCase): estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, mode) + def test_cov_update_thunks(self): + """Ensures covariance update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct an op that executes one covariance update per step. + global_step = training_util.get_or_create_global_step() + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + cov_update_op_thunks = fisher_estimator.cov_update_thunks + cov_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(cov_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_cov_values = sess.run(cov_matrices) + + # Ensure there's one update per covariance matrix. + self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) + + # Test is no-op if only 1 covariance matrix. + assert len(cov_matrices) > 1 + + for i in range(len(cov_matrices)): + # Compare new and old covariance values + new_cov_values = sess.run(cov_matrices) + is_cov_equal = [ + np.allclose(initial_cov_value, new_cov_value) + for (initial_cov_value, + new_cov_value) in zip(initial_cov_values, new_cov_values) + ] + num_cov_equal = sum(is_cov_equal) + + # Ensure exactly one covariance matrix changes per step. + self.assertEqual(num_cov_equal, len(cov_matrices) - i) + + # Run all covariance update ops. 
+ sess.run(cov_update_op) + sess.run(increment_global_step) + + def test_inv_update_thunks(self): + """Ensures inverse update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct op that updates one inverse per global step. + global_step = training_util.get_or_create_global_step() + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._inverses_by_damping.values() + ] + inv_update_op_thunks = fisher_estimator.inv_update_thunks + inv_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(inv_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_inv_values = sess.run(inv_matrices) + + # Ensure there's one update per inverse matrix. This is true as long as + # there's no fan-in/fan-out or parameter re-use. + self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) + + # Test is no-op if only 1 inverse matrix. + assert len(inv_matrices) > 1 + + # Assign each covariance matrix a value other than the identity. This + # ensures that the inverse matrices are updated to something different as + # well. 
+ cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + sess.run([ + cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0]))) + for cov_matrix in cov_matrices + ]) + + for i in range(len(inv_matrices)): + # Compare new and old inverse values + new_inv_values = sess.run(inv_matrices) + is_inv_equal = [ + np.allclose(initial_inv_value, new_inv_value) + for (initial_inv_value, + new_inv_value) in zip(initial_inv_values, new_inv_values) + ] + num_inv_equal = sum(is_inv_equal) + + # Ensure exactly one inverse matrix changes per step. + self.assertEqual(num_inv_equal, len(inv_matrices) - i) + + # Run all inverse update ops. + sess.run(inv_update_op) + sess.run(increment_global_step) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 02b0677824..d66395ded7 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -66,7 +66,21 @@ class _DeviceContextGenerator(object): class FisherEstimator(object): - """Fisher estimator class supporting various approximations of the Fisher.""" + """Fisher estimator class supporting various approximations of the Fisher. + + Attributes: + cov_update_thunks: list of no-arg functions. Executing a function adds + covariance update ops for a single FisherFactor to the graph. + cov_update_ops: List of Ops. Running an op updates covariance matrices for a + single FisherFactor. + cov_update_op: Op. Running updates covariance matrices for all + FisherFactors. + inv_update_thunks: list of no-arg functions. Executing a function adds + inverse update ops for a single FisherFactor to the graph. + inv_update_ops: List of Ops. Running an op updates inverse matrices for a + single FisherFactor. + inv_update_op: Op. Running updates inverse matrices for all FisherFactors. 
+ """ def __init__(self, variables, @@ -122,6 +136,7 @@ class FisherEstimator(object): ValueError: If no losses have been registered with layer_collection. """ + self._cov_ema_decay = cov_ema_decay self._variables = variables self._damping = damping self._estimation_mode = estimation_mode @@ -135,13 +150,31 @@ class FisherEstimator(object): "exact": self._get_grads_lists_exact } self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # TODO(b/70674513): Factor device placement outside of this class. self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) if inv_devices == cov_devices: self._inv_device_context_generator = self._cov_device_context_generator else: self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) - setup = self._setup(cov_ema_decay) - self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup + + self._instantiate_factors() + + self.cov_update_thunks = [ + self._create_cov_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] + self.cov_update_op = control_flow_ops.group( + self.cov_update_ops, name="cov_update_op") + + self.inv_update_thunks = [ + self._create_inv_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] + self.inv_update_op = control_flow_ops.group( + self.inv_update_ops, name="inv_update_op") @property def variables(self): @@ -202,18 +235,8 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _setup(self, cov_ema_decay): - """Sets up the various operations. - - Args: - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. 
- - Returns: - A triple (covs_update_op, invs_update_op, inv_updates_dict), where - covs_update_op is the grouped Op to update all the covariance estimates, - invs_update_op is the grouped Op to update all the inverses, and - inv_updates_dict is a dict mapping Op names to individual inverse updates. + def _instantiate_factors(self): + """Instantiates FisherFactors' variables. Raises: ValueError: If estimation_mode was improperly specified at construction. @@ -238,20 +261,30 @@ class FisherEstimator(object): with self._cov_device_context_generator(): fb.instantiate_factors(grads_list, self.damping) - cov_updates = [ - factor.make_covariance_update_op(cov_ema_decay) - for factor in self._layers.get_factors() - ] - inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} + def _create_cov_update_thunk(self, factor): + """Constructs a covariance update thunk for a single FisherFactor.""" - return control_flow_ops.group(*cov_updates), control_flow_ops.group( - *inv_updates.values()), inv_updates + def thunk(): + with tf_ops.name_scope( + "create_cov_update_thunk", values=[self._cov_ema_decay]): + return factor.make_covariance_update_op(self._cov_ema_decay) - def _get_all_inverse_update_ops(self): - for factor in self._layers.get_factors(): - with self._inv_device_context_generator(): - for op in factor.make_inverse_update_ops(): - yield op + return thunk + + def _create_inv_update_thunk(self, factor): + """Constructs an inverse update thunk for a single FisherFactor.""" + + def thunk(): + with tf_ops.name_scope("create_inv_update_thunk"): + with self._inv_device_context_generator(): + return control_flow_ops.group(factor.make_inverse_update_ops()) + + return thunk + + @property + def inv_updates_dict(self): + """Returns a dictionary mapping strings to inv_update_ops.""" + return {op.name: op for op in self.inv_update_ops} def _get_grads_lists_gradients(self, tensors): grads_flat = gradients_impl.gradients( |