path: root/tensorflow/contrib/kfac
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-11 15:39:45 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-01-11 15:43:28 -0800
commit     9322010d25f0b11c8bc2a498672f270708697425 (patch)
tree       2c0d1c59984bc5735593f2a428b7cb80832def8b /tensorflow/contrib/kfac
parent     733a8ed8918a548867b110c452b67c95dda537e2 (diff)
K-FAC: Add (cov|inv)_update_(ops|thunks) to FisherEstimator.
PiperOrigin-RevId: 181672525
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD               |   5
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py   | 114
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py                 |  87
3 files changed, 179 insertions(+), 27 deletions(-)
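
The commit adds per-factor update thunks to FisherEstimator, and the new test below schedules them with control_flow_ops.case so that exactly one factor is updated per global_step. As a rough, self-contained sketch of that usage pattern (not part of the commit; weights, layer_collection, and the decay/damping values are placeholders, assuming a TF 1.x graph-mode setup like the one in estimator_test.py):

# Hypothetical usage sketch of the new thunk API (assumes TF 1.x graph mode
# and a LayerCollection that already has layers and losses registered, as in
# estimator_test.py; `weights` and `layer_collection` are placeholders).
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_util

fisher_estimator = estimator.FisherEstimator(
    variables=[weights],
    layer_collection=layer_collection,
    cov_ema_decay=0.95,
    damping=0.1)

global_step = training_util.get_or_create_global_step()
num_factors = len(fisher_estimator.cov_update_thunks)

# Round-robin: run exactly one per-factor covariance update each step,
# selected by global_step modulo the number of factors.
cov_update_op = control_flow_ops.case(
    [(math_ops.equal(global_step % num_factors, i), thunk)
     for i, thunk in enumerate(fisher_estimator.cov_update_thunks)])

# The grouped ops built in __init__ remain available for updating everything
# at once.
run_all_cov_updates = fisher_estimator.cov_update_op
run_all_inv_updates = fisher_estimator.inv_update_op
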
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 4928bf2c10..17458ffa2a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -17,12 +17,17 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index 9b28c45c72..bfdb69ad02 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import utils
@@ -25,11 +27,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
@@ -119,6 +125,114 @@ class EstimatorTest(test.TestCase):
estimator.FisherEstimator([self.weights], 0.1, 0.2,
self.layer_collection, mode)
+ def test_cov_update_thunks(self):
+ """Ensures covariance update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimator(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ cov_ema_decay=0.0,
+ damping=0.0)
+
+ # Construct an op that executes one covariance update per step.
+ global_step = training_util.get_or_create_global_step()
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ cov_update_op_thunks = fisher_estimator.cov_update_thunks
+ cov_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(cov_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_cov_values = sess.run(cov_matrices)
+
+ # Ensure there's one update per covariance matrix.
+ self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
+
+ # Test is a no-op if there is only 1 covariance matrix.
+ assert len(cov_matrices) > 1
+
+ for i in range(len(cov_matrices)):
+ # Compare new and old covariance values
+ new_cov_values = sess.run(cov_matrices)
+ is_cov_equal = [
+ np.allclose(initial_cov_value, new_cov_value)
+ for (initial_cov_value,
+ new_cov_value) in zip(initial_cov_values, new_cov_values)
+ ]
+ num_cov_equal = sum(is_cov_equal)
+
+ # Ensure exactly one covariance matrix changes per step.
+ self.assertEqual(num_cov_equal, len(cov_matrices) - i)
+
+ # Run all covariance update ops.
+ sess.run(cov_update_op)
+ sess.run(increment_global_step)
+
+ def test_inv_update_thunks(self):
+ """Ensures inverse update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimator(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ cov_ema_decay=0.0,
+ damping=0.0)
+
+ # Construct op that updates one inverse per global step.
+ global_step = training_util.get_or_create_global_step()
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._inverses_by_damping.values()
+ ]
+ inv_update_op_thunks = fisher_estimator.inv_update_thunks
+ inv_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(inv_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_inv_values = sess.run(inv_matrices)
+
+ # Ensure there's one update per inverse matrix. This is true as long as
+ # there's no fan-in/fan-out or parameter re-use.
+ self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
+
+ # Test is a no-op if there is only 1 inverse matrix.
+ assert len(inv_matrices) > 1
+
+ # Assign each covariance matrix a value other than the identity. This
+ # ensures that the inverse matrices are updated to something different as
+ # well.
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ sess.run([
+ cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
+ for cov_matrix in cov_matrices
+ ])
+
+ for i in range(len(inv_matrices)):
+ # Compare new and old inverse values
+ new_inv_values = sess.run(inv_matrices)
+ is_inv_equal = [
+ np.allclose(initial_inv_value, new_inv_value)
+ for (initial_inv_value,
+ new_inv_value) in zip(initial_inv_values, new_inv_values)
+ ]
+ num_inv_equal = sum(is_inv_equal)
+
+ # Ensure exactly one inverse matrix changes per step.
+ self.assertEqual(num_inv_equal, len(inv_matrices) - i)
+
+ # Run all inverse update ops.
+ sess.run(inv_update_op)
+ sess.run(increment_global_step)
+
if __name__ == "__main__":
test.main()
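
The estimator.py changes below also expose grouped cov_update_op and inv_update_op attributes built in __init__, which can be driven straight from a training loop. A minimal sketch of one common K-FAC schedule, refreshing covariances every step and inverses only periodically; sess, train_op, num_steps, and invert_every are placeholders, and the schedule itself is illustrative rather than something this commit prescribes:

# Hypothetical training loop using the grouped update ops. Running inverse
# updates less often than covariance updates is a typical K-FAC
# cost/accuracy trade-off; the interval below is arbitrary.
invert_every = 20
for step in range(num_steps):
  sess.run(fisher_estimator.cov_update_op)    # refresh covariance EMAs
  if step % invert_every == 0:
    sess.run(fisher_estimator.inv_update_op)  # recompute all inverses
  sess.run(train_op)                          # one optimizer step
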
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index 02b0677824..d66395ded7 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -66,7 +66,21 @@ class _DeviceContextGenerator(object):
class FisherEstimator(object):
- """Fisher estimator class supporting various approximations of the Fisher."""
+ """Fisher estimator class supporting various approximations of the Fisher.
+
+ Attributes:
+ cov_update_thunks: list of no-arg functions. Executing a function adds
+ covariance update ops for a single FisherFactor to the graph.
+ cov_update_ops: List of Ops. Running an op updates covariance matrices for a
+ single FisherFactor.
+ cov_update_op: Op. Running it updates covariance matrices for all
+ FisherFactors.
+ inv_update_thunks: list of no-arg functions. Executing a function adds
+ inverse update ops for a single FisherFactor to the graph.
+ inv_update_ops: List of Ops. Running an op updates inverse matrices for a
+ single FisherFactor.
+ inv_update_op: Op. Running it updates inverse matrices for all FisherFactors.
+ """
def __init__(self,
variables,
@@ -122,6 +136,7 @@ class FisherEstimator(object):
ValueError: If no losses have been registered with layer_collection.
"""
+ self._cov_ema_decay = cov_ema_decay
self._variables = variables
self._damping = damping
self._estimation_mode = estimation_mode
@@ -135,13 +150,31 @@ class FisherEstimator(object):
"exact": self._get_grads_lists_exact
}
self._colocate_gradients_with_ops = colocate_gradients_with_ops
+
+ # TODO(b/70674513): Factor device placement outside of this class.
self._cov_device_context_generator = _DeviceContextGenerator(cov_devices)
if inv_devices == cov_devices:
self._inv_device_context_generator = self._cov_device_context_generator
else:
self._inv_device_context_generator = _DeviceContextGenerator(inv_devices)
- setup = self._setup(cov_ema_decay)
- self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup
+
+ self._instantiate_factors()
+
+ self.cov_update_thunks = [
+ self._create_cov_update_thunk(factor)
+ for factor in self._layers.get_factors()
+ ]
+ self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks]
+ self.cov_update_op = control_flow_ops.group(
+ self.cov_update_ops, name="cov_update_op")
+
+ self.inv_update_thunks = [
+ self._create_inv_update_thunk(factor)
+ for factor in self._layers.get_factors()
+ ]
+ self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks]
+ self.inv_update_op = control_flow_ops.group(
+ self.inv_update_ops, name="inv_update_op")
@property
def variables(self):
@@ -202,18 +235,8 @@ class FisherEstimator(object):
return self._apply_transformation(vecs_and_vars,
lambda fb, vec: fb.multiply(vec))
- def _setup(self, cov_ema_decay):
- """Sets up the various operations.
-
- Args:
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
-
- Returns:
- A triple (covs_update_op, invs_update_op, inv_updates_dict), where
- covs_update_op is the grouped Op to update all the covariance estimates,
- invs_update_op is the grouped Op to update all the inverses, and
- inv_updates_dict is a dict mapping Op names to individual inverse updates.
+ def _instantiate_factors(self):
+ """Instantiates FisherFactors' variables.
Raises:
ValueError: If estimation_mode was improperly specified at construction.
@@ -238,20 +261,30 @@ class FisherEstimator(object):
with self._cov_device_context_generator():
fb.instantiate_factors(grads_list, self.damping)
- cov_updates = [
- factor.make_covariance_update_op(cov_ema_decay)
- for factor in self._layers.get_factors()
- ]
- inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()}
+ def _create_cov_update_thunk(self, factor):
+ """Constructs a covariance update thunk for a single FisherFactor."""
- return control_flow_ops.group(*cov_updates), control_flow_ops.group(
- *inv_updates.values()), inv_updates
+ def thunk():
+ with tf_ops.name_scope(
+ "create_cov_update_thunk", values=[self._cov_ema_decay]):
+ return factor.make_covariance_update_op(self._cov_ema_decay)
- def _get_all_inverse_update_ops(self):
- for factor in self._layers.get_factors():
- with self._inv_device_context_generator():
- for op in factor.make_inverse_update_ops():
- yield op
+ return thunk
+
+ def _create_inv_update_thunk(self, factor):
+ """Constructs an inverse update thunk for a single FisherFactor."""
+
+ def thunk():
+ with tf_ops.name_scope("create_inv_update_thunk"):
+ with self._inv_device_context_generator():
+ return control_flow_ops.group(factor.make_inverse_update_ops())
+
+ return thunk
+
+ @property
+ def inv_updates_dict(self):
+ """Returns a dictionary mapping strings to inv_update_ops."""
+ return {op.name: op for op in self.inv_update_ops}
def _get_grads_lists_gradients(self, tensors):
grads_flat = gradients_impl.gradients(