author    A. Unique TensorFlower <gardener@tensorflow.org>    2016-11-07 22:48:17 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>     2016-11-08 16:18:43 -0800
commit 8024d0f0889ca0f5c47db70a69a30ce3a887d193 (patch)
tree   101a255ddae70bfe0261f6bbdd3348c407ccc58a
parent 55a26877c79ebc03a50d3f7da3e3695cca2c2a8f (diff)
Add ability to debias 0-initialized EMAs in `assign_moving_average`.
Change: 138481437
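A note on the math behind this change (illustration only, not part of the patch): a 0-initialized EMA of a constant `c` with decay `b` equals `c * (1 - b**t)` after `t` updates, so dividing by the factor `1 - b**t` recovers `c`. A minimal plain-Python sketch of that bias and its correction:

```python
# Illustrative sketch only: shows the 0-initialization bias and the
# debiasing factor 1 - b**t that this commit introduces.
b, c = 0.9, 5.0   # decay and the constant value being averaged
ema = 0.0         # 0-initialized EMA, biased toward 0
for t in range(1, 11):
    ema -= (1 - b) * (ema - c)       # the assign_moving_average update rule
    debiased = ema / (1 - b ** t)    # zero-debias scale factor
    print(t, round(ema, 4), round(debiased, 4))  # ema -> c*(1 - b**t); debiased -> c
```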
-rw-r--r--  tensorflow/python/training/moving_averages.py      | 92
-rw-r--r--  tensorflow/python/training/moving_averages_test.py | 14
2 files changed, 98 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index bdc7144014..37012067f4 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -29,7 +29,7 @@ from tensorflow.python.training import slot_creator
# TODO(touts): switch to variables.Variable.
-def assign_moving_average(variable, value, decay, name=None):
+def assign_moving_average(variable, value, decay, zero_debias=False, name=None):
"""Compute the moving average of a variable.
The moving average of 'variable' updated with 'value' is:
@@ -40,10 +40,20 @@ def assign_moving_average(variable, value, decay, name=None):
The new value of 'variable' can be set with the 'AssignSub' op as:
variable -= (1 - decay) * (variable - value)
+ Since variables that are initialized to a `0` value will be biased toward `0`,
+ `zero_debias` optionally enables scaling by the mathematically correct
+ debiasing factor of
+ 1 - decay ** num_updates
+ See `Adam: A Method for Stochastic Optimization` Section 3 for more details
+ (https://arxiv.org/abs/1412.6980).
+
Args:
variable: A Variable.
- value: A tensor with the same shape as 'variable'
+ value: A tensor with the same shape as 'variable'.
decay: A float Tensor or float value. The moving average decay.
+ zero_debias: A Python bool. If true, assume the variable is 0-initialized and
+ unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
+ `_zero_debias` for more details.
name: Optional name of the returned operation.
Returns:
@@ -56,9 +66,11 @@ def assign_moving_average(variable, value, decay, name=None):
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype)
- return state_ops.assign_sub(variable,
- (variable - value) * decay,
- name=scope)
+ if zero_debias:
+ update_delta = _zero_debias(variable, value, decay)
+ else:
+ update_delta = (variable - value) * decay
+ return state_ops.assign_sub(variable, update_delta, name=scope)
def weighted_moving_average(value,
@@ -121,6 +133,69 @@ def weighted_moving_average(value,
return math_ops.div(numerator, denominator, name=scope.name)
+def _zero_debias(unbiased_var, value, decay):
+ """Compute the delta required for a debiased Variable.
+
+ All exponential moving averages initialized with Tensors are initialized to 0,
+ and therefore are biased to 0. Variables initialized to 0 and used as EMAs are
+ similarly biased. This function creates the debias update amount according to
+ a scale factor, as in https://arxiv.org/abs/1412.6980.
+
+ To demonstrate the bias that results from 0-initialization, take an EMA that
+ was initialized to `0` with decay `b`. After `t` timesteps of seeing the
+ constant `c`, the variable will have the following value:
+
+ ```
+ EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
+ = c*(1 - b^t)
+ ```
+
+ To have the true value `c`, we would divide by the scale factor `1 - b^t`.
+
+ In order to perform debiasing, we use two shadow variables. One keeps track of
+ the biased estimate, and the other keeps track of the number of updates that
+ have occurred.
+
+ Args:
+ unbiased_var: A Variable representing the current value of the unbiased EMA.
+ value: A Tensor representing the most recent value.
+ decay: A Tensor representing `1-decay` for the EMA.
+
+ Returns:
+ The amount that the unbiased variable should be updated. Computing this
+ tensor will also update the shadow variables appropriately.
+ """
+ with variable_scope.variable_scope(
+ "ZeroDebias", values=[unbiased_var, value, decay]) as scope:
+ with ops.colocate_with(unbiased_var):
+ biased_var = variable_scope.get_variable(
+ unbiased_var.op.name + "_biased",
+ initializer=init_ops.zeros_initializer(
+ unbiased_var.get_shape(), dtype=unbiased_var.dtype),
+ trainable=False)
+ # Initializing the local_step to `0` would cause problems with the
+ # debiasing equation, so we instead initialize to `1`.
+ local_step = variable_scope.get_variable(
+ unbiased_var.op.name + "_local_step",
+ initializer=init_ops.ones_initializer([], dtype=unbiased_var.dtype),
+ trainable=False)
+
+ # Get update ops for both shadow variables.
+ update_biased = state_ops.assign_sub(biased_var,
+ (biased_var - value) * decay,
+ name=scope.name)
+ update_local_step = local_step.assign_add(1)
+
+ # Compute the value of the delta to update the unbiased EMA. Make sure to
+ # use the new values of the biased variable and the local step.
+ with ops.control_dependencies([update_biased, update_local_step]):
+ # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
+ unbiased_ema_delta = (unbiased_var - biased_var.ref() /
+ (1 - math_ops.pow(1.0 - decay, local_step.ref())))
+
+ return unbiased_ema_delta
+
+
class ExponentialMovingAverage(object):
"""Maintains moving averages of variables by employing an exponential decay.
@@ -216,7 +291,7 @@ class ExponentialMovingAverage(object):
ops to maintain moving averages.
The optional `num_updates` parameter allows one to tweak the decay rate
- dynamically. . It is typical to pass the count of training steps, usually
+ dynamically. It is typical to pass the count of training steps, usually
kept in a variable that is incremented at each step, in which case the
decay rate is lower at the start of training. This makes moving averages
move faster. If passed, the actual decay rate used is:
@@ -241,7 +316,8 @@ class ExponentialMovingAverage(object):
creates shadow variables for all elements of `var_list`. Shadow variables
for `Variable` objects are initialized to the variable's initial value.
They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
- For `Tensor` objects, the shadow variables are initialized to 0.
+ For `Tensor` objects, the shadow variables are initialized to 0 and zero
+ debiased (see docstring in `assign_moving_average` for more details).
Shadow variables are created with `trainable=False` and added to the
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
@@ -315,7 +391,7 @@ class ExponentialMovingAverage(object):
Returns:
A `Variable` object or `None` if the moving average of `var`
- is not maintained..
+ is not maintained.
"""
return self._averages.get(var, None)
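For context, a minimal usage sketch of the new flag against the graph-mode API of this era (session setup is assumed and is not part of the patch):

```python
import tensorflow as tf
from tensorflow.python.training import moving_averages

ema_var = tf.Variable([0.0, 0.0], trainable=False)  # 0-initialized EMA
value = tf.constant([1.0, 2.0])
# zero_debias=True routes the update through _zero_debias, which divides the
# biased estimate by 1 - decay**step using the two shadow variables above.
update_ema = moving_averages.assign_moving_average(
    ema_var, value, decay=0.99, zero_debias=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_ema)  # also updates the *_biased and *_local_step shadows
```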
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 8766b9e3ef..79ba3a780d 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -38,6 +38,20 @@ class MovingAveragesTest(tf.test.TestCase):
11.0 * 0.25 + 2.0 * (1.0 - 0.25)],
var.eval())
+ def testAssignMovingAverageWithZeroDebias(self):
+ with self.test_session():
+ var = tf.Variable([0.0, 0.0])
+ val = tf.constant([1.0, 2.0], tf.float32)
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=True)
+ tf.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ assign.op.run()
+ self.assertAllClose([1.0 * (1.0 - 0.25) / (1 - 0.25 ** 2),
+ 2.0 * (1.0 - 0.25) / (1 - 0.25 ** 2)],
+ var.eval())
+
def testWeightedMovingAverage(self):
with self.test_session() as sess:
decay = 0.5
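As a hand-check of the expected values in `testAssignMovingAverageWithZeroDebias` (the local step is initialized to `1` and incremented before use, so a single update yields an exponent of 2):

```python
# Hand-check only, not part of the test file: expected values after one update.
decay = 0.25
for v in (1.0, 2.0):
    biased = v * (1.0 - decay)          # one EMA step from a 0 start
    print(biased / (1.0 - decay ** 2))  # -> 0.8 and 1.6, as asserted above
```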