author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-21 13:20:58 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-02-21 13:26:10 -0800
commit     6419fd98883cd051213f0daeaea465728cf7a27c (patch)
tree       4f499bc6eccb4f2b08ac1bf00487ce530de7239c /tensorflow/contrib/kfac
parent     c8ccab3bda96bbda7adc281eaf095390806b06d7 (diff)
K-FAC: Levenberg-Marquardt (LM) algorithm for adapting damping; an example that trains an MNIST autoencoder model on variable-size training data with an adapted damping parameter; add KFACOptimizer.{update_damping}.
PiperOrigin-RevId: 186509305
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py   26
-rw-r--r--  tensorflow/contrib/kfac/python/ops/BUILD                          3
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py                  13
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py                 250
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py                      29
5 files changed, 263 insertions, 58 deletions
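For orientation before the diff, a minimal usage sketch of the damping-adaptation API introduced here. This is an illustrative sketch, not part of the commit: the toy model, placeholders, and layer-registration calls are assumptions, and only KfacOptimizer, set_damping_adaptation_params, and minimize follow the signatures added below.

# Hypothetical wiring of the new damping adaptation (illustration only).
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer as opt

x = tf.placeholder(tf.float32, shape=(None, 8))        # current batch
y = tf.placeholder(tf.float32, shape=(None, 8))
prev_x = tf.placeholder(tf.float32, shape=(None, 8))   # previous batch, re-fed
prev_y = tf.placeholder(tf.float32, shape=(None, 8))   # each step by the caller
w = tf.get_variable("w", shape=(8, 8))

def loss_fn(batch):
  # Scalar loss on an (inputs, targets) tuple, matching the new
  # `loss_fn(prev_train_batch)` contract in set_damping_adaptation_params.
  inputs, targets = batch
  return tf.reduce_mean(tf.square(tf.matmul(inputs, w) - targets))

layer_collection = lc.LayerCollection()
preds = tf.matmul(x, w)
# Registration calls are assumptions for illustration; a real model registers
# every layer and a predictive distribution with the LayerCollection.
layer_collection.register_fully_connected(w, x, preds)
layer_collection.register_normal_predictive_distribution(preds, var=1.0)

global_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
    learning_rate=0.01,
    cov_ema_decay=0.95,
    damping=1e-2,                      # initial value of the damping variable
    layer_collection=layer_collection)
optimizer.set_damping_adaptation_params(
    is_chief=True,
    prev_train_batch=(prev_x, prev_y),
    loss_fn=loss_fn,
    damping_adaptation_decay=0.95,
    damping_adaptation_interval=5)
# global_step must be passed so the optimizer knows when to adapt damping.
train_op = optimizer.minimize(loss_fn((x, y)), global_step=global_step)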
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
index bfdb69ad02..b12f7be769 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -90,49 +90,51 @@ class EstimatorTest(test.TestCase):
def testEstimatorInitManualRegistration(self):
with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
+ estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
+ self.layer_collection)
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
+ estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1,
self.layer_collection)
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
- estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
+ estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection)
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)
+ estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
+ self.layer_collection)
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
- estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection,
- "not_a_real_mode")
+ estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
+ self.layer_collection, "not_a_real_mode")
def testModeListCorrect(self):
with self._graph.as_default():
- est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
self.layer_collection)
self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
def testAllModesBuild(self):
for mode in _ALL_ESTIMATION_MODES:
with self._graph.as_default():
- estimator.FisherEstimator([self.weights], 0.1, 0.2,
+ estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
self.layer_collection, mode)
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
+ damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
- cov_ema_decay=0.0,
- damping=0.0)
+ cov_ema_decay=0.0)
# Construct an op that executes one covariance update per step.
global_step = training_util.get_or_create_global_step()
@@ -176,10 +178,10 @@ class EstimatorTest(test.TestCase):
"""Ensures inverse update ops run once per global_step."""
with self._graph.as_default(), self.test_session() as sess:
fisher_estimator = estimator.FisherEstimator(
+ damping_fn=lambda: 0.2,
variables=[self.weights],
layer_collection=self.layer_collection,
- cov_ema_decay=0.0,
- damping=0.0)
+ cov_ema_decay=0.0)
# Construct op that updates one inverse per global step.
global_step = training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index ee6549b109..c26230c2a8 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -144,10 +144,13 @@ py_library(
":fisher_estimator",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index a7b1f9d35c..a7e268c48a 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -83,9 +83,9 @@ class FisherEstimator(object):
"""
def __init__(self,
+ damping_fn,
variables,
cov_ema_decay,
- damping,
layer_collection,
estimation_mode="gradients",
colocate_gradients_with_ops=True,
@@ -94,16 +94,12 @@ class FisherEstimator(object):
"""Create a FisherEstimator object.
Args:
+ damping_fn: Function, accepts no arguments and returns the damping value.
variables: A list of the variables for which to estimate the Fisher. This
must match the variables registered in layer_collection (if it is not
None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
- damping: The damping factor used to stabilize training due to errors in
- the local approximation with the Fisher information matrix, and to
- regularize the update direction by making it closer to the gradient.
- (Higher damping means the update looks more like a standard gradient
- update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
blocks, kronecker factors, and losses associated with the
graph.
@@ -135,10 +131,9 @@ class FisherEstimator(object):
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
-
+ self._damping_fn = damping_fn
self._cov_ema_decay = cov_ema_decay
self._variables = variables
- self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
self._layers.create_subgraph()
@@ -182,7 +177,7 @@ class FisherEstimator(object):
@property
def damping(self):
- return self._damping
+ return self._damping_fn()
def _apply_transformation(self, vecs_and_vars, transform):
"""Applies an block-wise transformation to the corresponding vectors.
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 1974b07acf..5d456bcb79 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -23,11 +23,14 @@ from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products
from tensorflow.contrib.kfac.python.ops import estimator as est
# pylint enable=long-line
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import gradient_descent
@@ -61,6 +64,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
damping: The damping factor used to stabilize training due to errors in
the local approximation with the Fisher information matrix, and to
regularize the update direction by making it closer to the gradient.
+ If damping is adapted during training, then this value is used to
+ initialize the damping variable.
(Higher damping means the update looks more like a standard gradient
update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
@@ -105,10 +110,31 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
if variables is None:
variables = tf_variables.trainable_variables()
+ # The parameters below are required only if damping needs to be adapted.
+ # These parameters can be set by calling
+ # set_damping_adaptation_params() explicitly.
+ self._damping_adaptation_decay = 0.95
+ self._damping_adaptation_interval = 5
+ # See Section 6.5 of the K-FAC paper. omega(1) = pow(damping decay, interval)
+ self._omega = (
+ self._damping_adaptation_decay**self._damping_adaptation_interval)
+ self._adapt_damping = False
+ self._min_damping = 1e-5
+ self._prev_train_batch = None
+ self._is_chief = False
+ self._loss_fn = None
+ self._damping_constant = damping
+ self._damping = None
+ self._rho = None
+ self._prev_loss = None
+ self._q_model_change = None
+ self._update_damping_op = None
+
+ self._layers = layer_collection
self._fisher_est = est.FisherEstimator(
+ lambda: self.damping,
variables,
cov_ema_decay,
- damping,
layer_collection,
estimation_mode=estimation_mode,
colocate_gradients_with_ops=colocate_gradients_with_ops,
@@ -139,6 +165,60 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
super(KfacOptimizer, self).__init__(learning_rate, name=name)
+ def set_damping_adaptation_params(self,
+ is_chief,
+ prev_train_batch,
+ loss_fn,
+ min_damping=1e-5,
+ damping_adaptation_decay=0.99,
+ damping_adaptation_interval=5):
+ """Sets parameters required to adapt damping during training.
+
+ When called, enables damping adaptation according to the Levenberg-Marquardt
+ style rule described in Section 6.5 of "Optimizing Neural Networks with
+ Kronecker-factored Approximate Curvature".
+
+ Args:
+ is_chief: `Boolean`, `True` if the worker is chief.
+ prev_train_batch: Training data used to minimize loss in the previous
+ step. This will be used to evaluate loss by calling
+ `loss_fn(prev_train_batch)`.
+ loss_fn: `function` that takes as input training data tensor and returns
+ a scalar loss.
+ min_damping: `float`(Optional), Minimum value the damping parameter
+ can take. Default value 1e-5.
+ damping_adaptation_decay: `float`(Optional), The `damping` parameter is
+ multiplied by the `damping_adaptation_decay` every
+ `damping_adaptation_interval` number of iterations. Default value 0.99.
+ damping_adaptation_interval: `int`(Optional), Number of steps in between
+ updating the `damping` parameter. Default value 5.
+
+ Raises:
+ ValueError: If `set_damping_adaptation_params` has already been called
+ and `adapt_damping` is `True`.
+ """
+ if self._adapt_damping:
+ raise ValueError("Damping adaptation parameters already set.")
+ with variable_scope.variable_scope(self.get_name()):
+ self._adapt_damping = True
+ self._is_chief = is_chief
+ self._prev_train_batch = prev_train_batch
+ self._loss_fn = loss_fn
+ self._damping_adaptation_decay = damping_adaptation_decay
+ self._damping_adaptation_interval = damping_adaptation_interval
+ self._omega = (
+ self._damping_adaptation_decay**self._damping_adaptation_interval)
+ self._min_damping = min_damping
+
+ self._rho = variable_scope.get_variable(
+ "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
+ self._prev_loss = variable_scope.get_variable(
+ "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
+ self._q_model_change = variable_scope.get_variable(
+ "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
+ self._damping = variable_scope.get_variable(
+ "damping", initializer=self._damping_constant, trainable=False)
+
@property
def cov_update_thunks(self):
return self._fisher_est.cov_update_thunks
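For reference, the reduction ratio tracked by this rule, restated as a hedged plain-Python helper (the names are illustrative; the semantics follow the docstring above and Section 6.5 of the K-FAC paper). It compares the loss change actually observed on the previous batch against the change predicted by the quadratic model.

def reduction_ratio(loss_after_update, loss_before_update, q_model_change):
  # rho close to 1 means the quadratic model predicted the loss change well;
  # rho much smaller means the model was over-optimistic, so damping should grow.
  return (loss_after_update - loss_before_update) / q_model_change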
@@ -169,14 +249,34 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
@property
def damping(self):
- return self._fisher_est.damping
+ if self._damping:
+ return self._damping
+ else:
+ return self._damping_constant
+
+ @property
+ def damping_adaptation_interval(self):
+ return self._damping_adaptation_interval
def minimize(self, *args, **kwargs):
kwargs["var_list"] = kwargs.get("var_list") or self.variables
if set(kwargs["var_list"]) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
- return super(KfacOptimizer, self).minimize(*args, **kwargs)
+ if self._adapt_damping and self._is_chief:
+ global_step = kwargs.get("global_step", None)
+ if not global_step:
+ raise KeyError("global_step needs to be passed to optimizer.minimize "
+ "if damping parameter is adapted.")
+ update_damping_op = self._update_damping(self._prev_train_batch,
+ global_step)
+ with ops.control_dependencies([update_damping_op]):
+ loss = args[0]
+ loss_assign_op = state_ops.assign(self._prev_loss, loss)
+ train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
+ return control_flow_ops.group(loss_assign_op, train_op)
+ else:
+ return super(KfacOptimizer, self).minimize(*args, **kwargs)
def compute_gradients(self, *args, **kwargs):
# args[1] could be our var_list
@@ -296,6 +396,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
+ def _compute_prev_updates(self, variables):
+ """Computes previous updates as negative velocities scaled by learning rate.
+
+ Args:
+ variables: List of variables in the graph that the update will be
+ applied to.
+
+ Returns:
+ List of previous updates applied to the `variables`.
+ """
+ return list(
+ -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
+ for var in variables)
+
def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
variables):
"""Compute optimal update hyperparameters from the quadratic model.
@@ -374,9 +488,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
[_inner_product_list(grads, prev_updates)]])
- sol = _two_by_two_solve(m, c)
- alpha = -sol[0]
- mu = -sol[1]
+ sol = -1. * _two_by_two_solve(m, c)
+ alpha = sol[0]
+ mu = sol[1]
qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
return alpha, mu, qmodel_change
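A small numeric check of the sign change above (a standalone NumPy sketch with made-up numbers, not library code): the optimal coefficients solve M [alpha; mu] = -c, and the predicted change of the quadratic model is 0.5 * sol^T c, which is non-positive when M is positive definite.

import numpy as np

m = np.array([[2.0, 0.3],
              [0.3, 1.0]])             # curvature block of the quadratic model
c = np.array([[0.8],
              [0.1]])                  # linear (gradient) terms
sol = -1.0 * np.linalg.solve(m, c)     # mirrors: sol = -1. * _two_by_two_solve(m, c)
alpha, mu = sol[0, 0], sol[1, 0]
qmodel_change = 0.5 * np.sum(sol * c)  # predicted (negative) change in the loss
print(alpha, mu, qmodel_change)        # approx. -0.403, 0.073, -0.158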
@@ -404,6 +518,52 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
return control_flow_ops.cond(
math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
+ def _assign_q_model_change(self, q_model_change):
+ """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
+
+ Note only the chief worker does the assignment.
+
+ Args:
+ q_model_change: Scalar tensor of type `float32`.
+
+ Returns:
+ If `adapt_damping` is `True` then returns an assign op, otherwise returns
+ a no_op().
+ """
+ if self._adapt_damping and self._is_chief:
+ q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
+ else:
+ q_model_assign_op = control_flow_ops.no_op()
+ return q_model_assign_op
+
+ def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
+ precon_grads_and_vars):
+ """Wrapper function for `self._compute_qmodel_hyperparams`.
+
+ Constructs a list of preconditioned gradients and variables. Also creates an
+ op to assign the computed q model change to `self._q_model_change`.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ precon_grads_and_vars: List of (preconditioned gradients, variable)
+ pairs.
+
+ Returns:
+ (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
+ the quadratic model, `q_model_assign_op` assigns the computed q model
+ change to `self._q_model_change`.
+ """
+ precon_grads = list(
+ precon_grad for (precon_grad, _) in precon_grads_and_vars)
+ grads = list(grad for (grad, _) in grads_and_vars)
+ variables = list(var for (_, var) in grads_and_vars)
+ prev_updates = self._compute_prev_updates(variables)
+ # Compute optimal velocity update parameters according to quadratic model
+ alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
+ precon_grads, prev_updates, grads, variables)
+
+ return alpha, mu, self._assign_q_model_change(q_model_change)
+
def _compute_update_steps(self, grads_and_vars):
"""Computes the update steps for the variables given the gradients.
@@ -411,8 +571,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
grads_and_vars: List of (gradient, variable) pairs.
Returns:
- An 'Operation that computes the update steps for the given variables.
+ A list of (assign_op, var) tuples, where `assign_op` assigns the update
+ steps to `var`.
"""
+
if self._momentum_type == "regular":
# Compute "preconditioned" gradient.
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
@@ -423,8 +585,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
precon_grads_and_vars)
# Update the velocity with this and return it as the step.
- return self._update_velocities(precon_grads_and_vars, self._momentum)
-
+ if self._adapt_damping and self._is_chief:
+ _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
+ grads_and_vars, precon_grads_and_vars)
+ with ops.control_dependencies([q_model_assign_op]):
+ return self._update_velocities(precon_grads_and_vars, self._momentum)
+ else:
+ return self._update_velocities(precon_grads_and_vars, self._momentum)
elif self._momentum_type == "adam":
# Update velocity.
velocities_and_vars = self._update_velocities(grads_and_vars,
@@ -436,23 +603,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
# Compute "preconditioned" gradient.
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
- # Extract out singleton lists from the tuple-lists
- precon_grads = list(
- precon_grad for (precon_grad, _) in precon_grads_and_vars)
- grads = list(grad for (grad, _) in grads_and_vars)
- variables = list(var for (_, var) in grads_and_vars)
- # previous updates are the negative velocities (up to scaling by LR)
- prev_updates = list(
- -self._zeros_slot(var, "velocity", self._name) for var in variables)
-
# Compute optimal velocity update parameters according to quadratic model
- alpha, mu, _ = self._compute_qmodel_hyperparams(
- precon_grads, prev_updates, grads, variables)
+ alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
+ grads_and_vars, precon_grads_and_vars)
- # Update the velocity with precon_grads according to these params
- # and return it as the step.
- return self._update_velocities(
- precon_grads_and_vars, mu, vec_coeff=-alpha)
+ with ops.control_dependencies([q_model_assign_op]):
+ return self._update_velocities(
+ precon_grads_and_vars, mu, vec_coeff=-alpha)
def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
"""Updates the velocities of the variables with the given vectors.
@@ -482,6 +639,51 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
# Go through variable and update its associated part of the velocity vector.
return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
+ # TODO(b/73448937): Move all update damping code to a separate class/function.
+ def _update_damping(self, prev_batch, global_step):
+ """Adapts damping parameter. Check KFAC (Section 6.5) for the details.
+
+ The damping parameter is updated according to the Levenberg-Marquardt rule
+ every `self._damping_adaptation_interval` iterations.
+
+ Args:
+ prev_batch: Tensor or tuple of tensors which can be passed to
+ `self._loss_fn` to evaluate loss.
+ global_step: `Variable` which keeps track of the number of times the
+ training variables have been updated.
+ Returns:
+ A `tf.cond` op which updates the damping parameter.
+ """
+ def compute_damping():
+ """"Adapts damping parameter based on "reduction ratio".
+
+ Reduction ratio captures how closely the quadratic approximation to the
+ loss function approximates the actual loss within a trust region. The
+ damping update tries to make the damping as small as possible while
+ maintaining the property that the quadratic model remains a good local
+ approximation to the loss function.
+
+ Returns:
+ An Op to assign newly computed damping value to `self._damping`.
+ """
+ prev_batch_loss = self._loss_fn(prev_batch)
+ with ops.control_dependencies([prev_batch_loss]):
+ rho_assign = self._rho.assign(
+ (prev_batch_loss - self._prev_loss) / self._q_model_change)
+ with ops.control_dependencies([rho_assign]):
+ new_damping = control_flow_ops.case(
+ [(self._rho < 0.25, lambda: self.damping / self._omega),
+ (self._rho > 0.75, lambda: self.damping * self._omega)],
+ lambda: self.damping)
+ with ops.control_dependencies([new_damping]):
+ new_damping_min = math_ops.maximum(new_damping, self._min_damping)
+ return control_flow_ops.group(self._damping.assign(new_damping_min))
+
+ return control_flow_ops.cond(
+ math_ops.equal(
+ math_ops.mod(global_step + 1, self._damping_adaptation_interval),
+ 0), compute_damping, control_flow_ops.no_op)
+
def _inner_product_list(list1, list2):
return math_ops.add_n(
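The same damping schedule in plain Python, for readers skimming the TF control flow above (an illustrative restatement using the constructor defaults decay=0.95 and interval=5, not library code):

def maybe_update_damping(step, damping, rho, omega=0.95**5,
                         interval=5, min_damping=1e-5):
  # Damping only changes every `interval` steps (same gate as the tf.cond).
  if (step + 1) % interval != 0:
    return damping
  if rho < 0.25:        # quadratic model over-predicted the loss reduction
    damping /= omega    # omega < 1, so this *increases* damping
  elif rho > 0.75:      # model predicted the loss change well
    damping *= omega    # trust the model more: decrease damping
  return max(damping, min_damping)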
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index f5bd97cb4e..88e6fb20e8 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -241,19 +241,22 @@ class SubGraph(object):
# Set of all ancestor Tensors, Ops to 'outputs'.
self._members = set()
- self._recurse_add(outputs)
-
- def _recurse_add(self, nodes):
- """Recursively adds all of nodes' ancestors."""
- for node in nodes:
- if node in self._members:
- continue
- self._members.add(node)
-
- if isinstance(node, ops.Tensor):
- self._recurse_add((node.op,))
- elif isinstance(node, ops.Operation):
- self._recurse_add(node.inputs)
+ self._iter_add(outputs)
+
+ def _iter_add(self, root):
+ """Iteratively adds all of nodes' ancestors using depth first search."""
+ stack = [root]
+ while stack:
+ nodes = stack.pop()
+ for node in nodes:
+ if node in self._members:
+ continue
+ self._members.add(node)
+
+ if isinstance(node, ops.Tensor):
+ stack.append((node.op,))
+ elif isinstance(node, ops.Operation):
+ stack.append(node.inputs)
def is_member(self, node):
"""Check if 'node' is in this subgraph."""