author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-11 13:10:17 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-11-11 13:27:38 -0800
commit  a2cbdbabac342b3a05624ec42b56d9d87b4444a8
tree    6ab122a9a8b63d9341b19b39021be038194ffcc0
parent  a8967c15a45be5517dec8c2c343f84e36b001b7b
Use `zero_debias=True` for Tensors in the ExponentialMovingAverage class.
Change: 138910410
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py  22
-rw-r--r--  tensorflow/python/training/moving_averages.py                              12
-rw-r--r--  tensorflow/python/training/moving_averages_test.py                         41
3 files changed, 36 insertions(+), 39 deletions(-)
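
For context, the docstring removed below cites the bias correction from Adam (section 3): a moving average initialized at zero is biased toward zero, and dividing by `1 - decay^t` (with `t` the number of updates) removes that bias. The following is a hypothetical out-of-graph sketch of that correction, not part of this commit or the TensorFlow API; the in-graph `_zero_debias` helper tracks the step count in a `local_step` variable instead:

    # Hypothetical sketch: zero-debiasing a moving average that starts at 0.
    # Dividing the biased average by (1 - decay**t) removes the bias toward 0.
    def debiased_ema(values, decay=0.99):
        biased = 0.0
        for t, value in enumerate(values, start=1):
            biased = decay * biased + (1.0 - decay) * value
            yield biased / (1.0 - decay ** t)

    # With decay=0.25 and a single observation of 40.0, the biased average is
    # 0.75 * 40.0 == 30.0, and the debiased estimate recovers 40.0.
    print(list(debiased_ema([40.0], decay=0.25)))  # [40.0]
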
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index fe9866862e..2139419289 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -168,11 +168,6 @@ def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
def get_mean_baseline(ema_decay=0.99, name=None):
"""ExponentialMovingAverage baseline.
- EMA initializes to 0, which introduces a bias. This baseline implements the
- bias correction term from Adam (section 3 of
- https://arxiv.org/pdf/1412.6980v8.pdf), dividing by `1 - ema_decay^t`, where
- `t` is the step count.
-
Args:
ema_decay: decay rate for the ExponentialMovingAverage.
name: name for variable scope of the ExponentialMovingAverage.
@@ -189,21 +184,10 @@ def get_mean_baseline(ema_decay=0.99, name=None):
ema = training.ExponentialMovingAverage(decay=ema_decay)
update_op = ema.apply([reduced_loss])
- # The bias correction term requires keeping track of how many times the
- # EMA has been updated. Creating a variable here to do so. The global step
- # is not used because it may or may not track exactly the number of times
- # the EMA is updated.
- ema_var = ema.average(reduced_loss)
- assert ema_var is not None
- with ops.colocate_with(ema_var):
- num_updates = vs.get_variable(
- "local_ema_step", initializer=0, trainable=False)
- num_updates = num_updates.assign_add(1)
- bias_correction = 1. - math_ops.pow(ema_decay, math_ops.cast(
- num_updates, reduced_loss.dtype))
-
with ops.control_dependencies([update_op]):
- baseline = ema.average(reduced_loss) / bias_correction
+ # Using `identity` causes an op to be added in this context, which
+ # triggers the update. Removing the `identity` means nothing is updated.
+ baseline = array_ops.identity(ema.average(reduced_loss))
return baseline
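
The comment added above captures the subtlety being introduced here: a `control_dependencies` block only forces `update_op` to run if some op is actually created inside it, which is what the `identity` provides. A minimal TF 1.x-style usage sketch of the same pattern (illustrative only; the loss tensor is made up):

    import tensorflow as tf

    # Fetching `baseline` also runs `update_op`, because `tf.identity`
    # creates a new op inside the control-dependency scope.
    loss = tf.reduce_mean(tf.square(tf.random_normal([8])))
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_op = ema.apply([loss])
    with tf.control_dependencies([update_op]):
        baseline = tf.identity(ema.average(loss))
    # Returning `ema.average(loss)` directly would silently skip the
    # moving-average update whenever only `baseline` is evaluated.
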
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 5e14309ea8..1b5fb019e2 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -166,17 +166,17 @@ def _zero_debias(unbiased_var, value, decay):
tensor will also update the shadow variables appropriately.
"""
with variable_scope.variable_scope(
- "ZeroDebias", values=[unbiased_var, value, decay]) as scope:
+ unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope:
with ops.colocate_with(unbiased_var):
biased_var = variable_scope.get_variable(
- unbiased_var.op.name + "_biased",
+ "biased",
initializer=init_ops.zeros_initializer(
unbiased_var.get_shape(), dtype=unbiased_var.dtype),
trainable=False)
# Initializing the local_step to `0` would cause problems with the
# debiasing equation, so we instead initialize to `1`.
local_step = variable_scope.get_variable(
- name=unbiased_var.op.name + "_local_step",
+ "local_step",
shape=[], dtype=unbiased_var.dtype,
initializer=init_ops.ones_initializer(),
trainable=False)
@@ -344,6 +344,7 @@ class ExponentialMovingAverage(object):
# TODO(touts): op_scope
if var_list is None:
var_list = variables.trainable_variables()
+ zero_debias_true = set() # set of vars to set `zero_debias=True`
for var in var_list:
if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
dtypes.float64]:
@@ -369,6 +370,7 @@ class ExponentialMovingAverage(object):
var,
self._name,
colocate_with_primary=(var.op.type == "Variable"))
+ zero_debias_true.add(avg)
self._averages[var] = avg
with ops.name_scope(self._name) as scope:
@@ -381,7 +383,9 @@ class ExponentialMovingAverage(object):
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
- updates.append(assign_moving_average(self._averages[var], var, decay))
+ zero_debias = self._averages[var] in zero_debias_true
+ updates.append(assign_moving_average(
+ self._averages[var], var, decay, zero_debias=zero_debias))
return control_flow_ops.group(*updates, name=scope)
def average(self, var):
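
Because `zero_debias=True` is now used for moving averages of plain Tensors, `apply()` creates two extra helper variables per tensor (`.../biased` and `.../local_step`), and the test changes below extend the expected `variables_to_restore()` keys accordingly. A rough TF 1.x-era usage sketch (hypothetical variable names):

    import tensorflow as tf

    v0 = tf.Variable(10.0, name="v0")
    v1 = tf.Variable(30.0, name="v1")
    tensor2 = v0 + v1  # op name "add"
    ema = tf.train.ExponentialMovingAverage(0.25, name="foo")
    ema.apply([v0, v1, tensor2])
    # Per the updated test expectations below, the restore map now also
    # contains "add/foo/biased" and "add/foo/local_step" alongside
    # "v0/foo", "v1/foo", and "add/foo".
    print(sorted(ema.variables_to_restore().keys()))
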
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 79ba3a780d..ebcd3a34cc 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -133,7 +133,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertAllClose(expected, avg0.eval())
expected = _Repeat(30.0 * dk + 30.0 * (1 - dk), dim)
self.assertAllClose(expected, avg1.eval())
- expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk), dim)
+ expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk) / (1 - dk ** 2), dim)
self.assertAllClose(expected, avg2.eval())
# Again, update the averages and check.
@@ -145,7 +145,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
dim)
self.assertAllClose(expected, avg1.eval())
expected = _Repeat(((0.0 * dk + (10.0 + 30.0) * (1 - dk)) * dk +
- (10.0 + 30.0) * (1 - dk)),
+ (10.0 + 30.0) * (1 - dk)) / (1 - dk ** 3),
dim)
self.assertAllClose(expected, avg2.eval())
@@ -202,21 +202,25 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
# Add a non-trainable variable.
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
- ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
- self.assertEqual("v0/foo_avg", ema.average_name(v0))
- self.assertEqual("v1/foo_avg", ema.average_name(v1))
- self.assertEqual("add/foo_avg", ema.average_name(tensor2))
+ ema = tf.train.ExponentialMovingAverage(0.25, name="foo")
+ self.assertEqual("v0/foo", ema.average_name(v0))
+ self.assertEqual("v1/foo", ema.average_name(v1))
+ self.assertEqual("add/foo", ema.average_name(tensor2))
ema.apply([v0, v1, tensor2])
vars_to_restore = ema.variables_to_restore()
# vars_to_restore should contain the following:
- # {v0/foo_avg : v0,
- # v1/foo_avg : v1,
- # add/foo_avg : add/foo_avg
+ # {v0/foo : v0,
+ # v1/foo : v1,
+ # add/foo : add/foo,
+ # add/foo/biased: add/foo/biased,
+ # add/foo/local_step: add/foo/local_step,
# v2 : v2}
self.assertEqual(sorted(vars_to_restore.keys()),
sorted([ema.average_name(v0),
ema.average_name(v1),
ema.average_name(tensor2),
+ ema.average_name(tensor2) + "/biased",
+ ema.average_name(tensor2) + "/local_step",
v2.op.name]))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
@@ -232,21 +236,26 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
with tf.variable_scope("scope2"):
- ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
- self.assertEqual("scope2/scope1/v0/foo_avg", ema.average_name(v0))
- self.assertEqual("scope2/scope1/v1/foo_avg", ema.average_name(v1))
- self.assertEqual("scope2/scope1/add/foo_avg", ema.average_name(tensor2))
+ ema = tf.train.ExponentialMovingAverage(0.25, name="foo")
+ self.assertEqual("scope2/scope1/v0/foo", ema.average_name(v0))
+ self.assertEqual("scope2/scope1/v1/foo", ema.average_name(v1))
+ self.assertEqual("scope2/scope1/add/foo", ema.average_name(tensor2))
ema.apply([v0, v1, tensor2])
vars_to_restore = ema.variables_to_restore()
# vars_to_restore should contain the following:
- # {scope2/scope1/v0/foo_avg : v0,
- # scope2/scope1/v1/foo_avg : v1,
- # scope2/scope1/add/foo_avg : add/foo_avg
+ # {scope2/scope1/v0/foo : v0,
+ # scope2/scope1/v1/foo : v1,
+ # scope2/scope1/add/foo : add/foo,
+ # scope2/scope2/scope1/add/foo/biased: add/foo/biased,
+ # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step,
# scope1/v2 : v2}
+ sc = "scope2/"
self.assertEqual(sorted(vars_to_restore.keys()),
sorted([ema.average_name(v0),
ema.average_name(v1),
ema.average_name(tensor2),
+ sc + ema.average_name(tensor2) + "/biased",
+ sc + ema.average_name(tensor2) + "/local_step",
v2.op.name]))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)