author A. Unique TensorFlower <gardener@tensorflow.org> 2016-09-16 11:30:39 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-09-16 12:47:33 -0700
commit 25ce56037a1b3ac3073547711da63629737f29fe (patch)
tree 74d8a81f85df1833781f6b33f40c7087e71b4098
parent 3505673a75b23c5f48983cf0b0230f0e7004d05e (diff)
Correct the bias introduced by default EMA initialization in score function mean baseline.

Change: 133414673
-rw-r--r-- tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py | 32
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py | 34
2 files changed, 56 insertions, 10 deletions
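The idea behind the fix, in brief: TensorFlow's ExponentialMovingAverage shadow variable starts at 0, so the mean baseline underestimates the loss for the first several updates. The commit divides the EMA by 1 - ema_decay^t, the bias-correction term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf). Below is a minimal plain-Python sketch of that correction; it is not part of the commit and the names are illustrative only.

def bias_corrected_ema(losses, ema_decay=0.8):
  """Bias-corrected EMA baseline after each update (illustration only)."""
  ema = 0.  # ExponentialMovingAverage shadow variables initialize to 0.
  baselines = []
  for t, loss in enumerate(losses, start=1):
    ema -= (1. - ema_decay) * (ema - loss)  # standard EMA update
    bias_correction = 1. - ema_decay**t     # Adam-style correction term
    baselines.append(ema / bias_correction)
  return baselines

# With a constant loss, the corrected baseline equals the loss from the first
# step; the raw EMA (0.4, 0.72, ...) would only approach it gradually.
print(bias_corrected_ema([2., 2., 2.]))  # approximately [2.0, 2.0, 2.0]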
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
index 56936e6c38..e1edbc908c 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
@@ -61,18 +61,27 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
  def testScoreFunctionWithMeanBaseline(self):
    ema_decay = 0.8
+    num_steps = 6
    x = st.BernoulliTensor(
        p=self._p,
        loss_fn=sge.get_score_function_with_baseline(
            sge.get_mean_baseline(ema_decay)))
    sf = x.loss(self._final_loss)
-    expected = tf.log(self._p) * (self._final_loss -
-                                  (1. - ema_decay) * self._final_loss)
+    # Expected EMA value
+    ema = 0.
+    for _ in range(num_steps):
+      ema -= (1. - ema_decay) * (ema - self._final_loss)
+
+    # Baseline is EMA with bias correction
+    bias_correction = 1. - ema_decay**num_steps
+    baseline = ema / bias_correction
+    expected = tf.log(self._p) * (self._final_loss - baseline)

    with self.test_session() as sess:
      sess.run(tf.initialize_all_variables())
-      sess.run(sf)  # run to update EMA
+      for _ in range(num_steps - 1):
+        sess.run(sf)  # run to update EMA
      self.assertAllClose(*sess.run([expected, sf]))

  def testScoreFunctionWithAdvantageFn(self):
@@ -87,6 +96,23 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
    self._testScoreFunction(
        sge.get_score_function_with_advantage(advantage_fn), expected)

+  def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
+    ema_decay = 0.8
+    x = st.BernoulliTensor(
+        p=self._p,
+        loss_fn=sge.get_score_function_with_baseline(
+            sge.get_mean_baseline(ema_decay)))
+    y = st.BernoulliTensor(
+        p=self._p,
+        loss_fn=sge.get_score_function_with_baseline(
+            sge.get_mean_baseline(ema_decay)))
+    sf_x = x.loss(self._final_loss)
+    sf_y = y.loss(self._final_loss)
+    with self.test_session() as sess:
+      # Smoke test
+      sess.run(tf.initialize_all_variables())
+      sess.run([sf_x, sf_y])
+

if __name__ == "__main__":
  tf.test.main()
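As a quick check of the new expectation (not part of the commit): for a constant loss L, the loop in the test yields ema = (1 - 0.8^6) * L after six updates, so dividing by the bias correction 1 - 0.8^6 recovers exactly L. The old expectation, (1 - ema_decay) * L, only matched the uncorrected EMA after a single update. The value of self._final_loss is set in the test's setUp and is not visible in this hunk.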
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index 7cb8ef06f9..a0d37f81ae 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -59,6 +59,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import training
from tensorflow.python.util.all_util import make_all
@@ -164,12 +165,17 @@ def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
  return score_function_with_baseline


-def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
+def get_mean_baseline(ema_decay=0.99, name=None):
  """ExponentialMovingAverage baseline.

+  EMA initializes to 0, which introduces a bias. This baseline implements the
+  bias correction term from Adam (section 3 of
+  https://arxiv.org/pdf/1412.6980v8.pdf), dividing by `1 - ema_decay^t`, where
+  `t` is the step count.
+
  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
-    name: name to prepend ops with.
+    name: name for variable scope of the ExponentialMovingAverage.

  Returns:
    Callable baseline function that takes the `DistributionTensor` (unused) and
@@ -177,14 +183,28 @@ def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
"""
def mean_baseline(_, loss):
- with ops.name_scope(name):
- ema = training.ExponentialMovingAverage(decay=ema_decay)
+ with vs.variable_scope(name, default_name="MeanBaseline"):
reduced_loss = math_ops.reduce_mean(loss)
+
+ ema = training.ExponentialMovingAverage(decay=ema_decay)
update_op = ema.apply([reduced_loss])
+
+ # The bias correction term requires keeping track of how many times the
+ # EMA has been updated. Creating a variable here to do so. The global step
+ # is not used because it may or may not track exactly the number of times
+ # the EMA is updated.
+ ema_var = ema.average(reduced_loss)
+ assert ema_var is not None
+ with ops.colocate_with(ema_var):
+ num_updates = vs.get_variable(
+ "local_ema_step", initializer=0, trainable=False)
+ num_updates = num_updates.assign_add(1)
+ bias_correction = 1. - math_ops.pow(ema_decay, math_ops.cast(
+ num_updates, reduced_loss.dtype))
+
with ops.control_dependencies([update_op]):
- # TODO(rsepassi): Possibly implement the initialization bias correction
- # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
- baseline = ema.average(reduced_loss)
+ baseline = ema.average(reduced_loss) / bias_correction
+
return baseline
return mean_baseline
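For orientation, a usage sketch that mirrors the tests above. The st/sge module aliases and contrib import paths are assumptions based on the test file and may differ at this commit; the session and initialize_all_variables calls follow the TF 0.x style already used in the diff.

import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st

p = tf.constant([0.5])
final_loss = tf.constant([9.0])

x = st.BernoulliTensor(
    p=p,
    loss_fn=sge.get_score_function_with_baseline(
        sge.get_mean_baseline(ema_decay=0.8)))
sf = x.loss(final_loss)  # each run also updates the EMA and its step counter

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for _ in range(6):
    print(sess.run(sf))  # baseline is bias-corrected from the very first run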