author    2016-09-16 11:30:39 -0800
committer 2016-09-16 12:47:33 -0700
commit    25ce56037a1b3ac3073547711da63629737f29fe
tree      74d8a81f85df1833781f6b33f40c7087e71b4098
parent    3505673a75b23c5f48983cf0b0230f0e7004d05e
Correct the bias introduced by default EMA initialization in score function
mean baseline.
Change: 133414673
 tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py | 32
 tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py               | 34
 2 files changed, 56 insertions(+), 10 deletions(-)
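
Why the fix matters, as a quick sketch (illustration only, not part of the commit): TensorFlow's ExponentialMovingAverage initializes its shadow variable to 0, so the EMA of a roughly constant loss L reaches only (1 - decay^t) * L after t updates, biasing the baseline toward 0. Dividing by 1 - decay^t, as in Adam, removes that bias exactly for a constant signal. A pure-Python check, with made-up values for decay and loss:

decay = 0.8
loss = 5.0
ema = 0.0  # shadow variable starts at 0, the source of the bias
for t in range(1, 7):
    # Same update rule as ExponentialMovingAverage.apply:
    #   ema <- ema - (1 - decay) * (ema - value)
    ema -= (1.0 - decay) * (ema - loss)
    corrected = ema / (1.0 - decay**t)  # Adam-style bias correction
    # Raw EMA only approaches 5.0 as t grows; `corrected` recovers
    # 5.0 at every step (up to float round-off).
    print(t, ema, corrected)
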
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
index 56936e6c38..e1edbc908c 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
@@ -61,18 +61,27 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
 
   def testScoreFunctionWithMeanBaseline(self):
     ema_decay = 0.8
+    num_steps = 6
     x = st.BernoulliTensor(
         p=self._p,
         loss_fn=sge.get_score_function_with_baseline(
             sge.get_mean_baseline(ema_decay)))
     sf = x.loss(self._final_loss)
-    expected = tf.log(self._p) * (self._final_loss -
-                                  (1. - ema_decay) * self._final_loss)
+    # Expected EMA value
+    ema = 0.
+    for _ in range(num_steps):
+      ema -= (1. - ema_decay) * (ema - self._final_loss)
+
+    # Baseline is EMA with bias correction
+    bias_correction = 1. - ema_decay**num_steps
+    baseline = ema / bias_correction
+    expected = tf.log(self._p) * (self._final_loss - baseline)
 
     with self.test_session() as sess:
       sess.run(tf.initialize_all_variables())
-      sess.run(sf)  # run to update EMA
+      for _ in range(num_steps - 1):
+        sess.run(sf)  # run to update EMA
       self.assertAllClose(*sess.run([expected, sf]))
 
   def testScoreFunctionWithAdvantageFn(self):
@@ -87,6 +96,23 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
     self._testScoreFunction(
         sge.get_score_function_with_advantage(advantage_fn),
         expected)
 
+  def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
+    ema_decay = 0.8
+    x = st.BernoulliTensor(
+        p=self._p,
+        loss_fn=sge.get_score_function_with_baseline(
+            sge.get_mean_baseline(ema_decay)))
+    y = st.BernoulliTensor(
+        p=self._p,
+        loss_fn=sge.get_score_function_with_baseline(
+            sge.get_mean_baseline(ema_decay)))
+    sf_x = x.loss(self._final_loss)
+    sf_y = y.loss(self._final_loss)
+    with self.test_session() as sess:
+      # Smoke test
+      sess.run(tf.initialize_all_variables())
+      sess.run([sf_x, sf_y])
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index 7cb8ef06f9..a0d37f81ae 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -59,6 +59,7 @@ from __future__ import print_function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.training import training
 from tensorflow.python.util.all_util import make_all
 
@@ -164,12 +165,17 @@ def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
   return score_function_with_baseline
 
 
-def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
+def get_mean_baseline(ema_decay=0.99, name=None):
   """ExponentialMovingAverage baseline.
 
+  EMA initializes to 0, which introduces a bias. This baseline implements the
+  bias correction term from Adam (section 3 of
+  https://arxiv.org/pdf/1412.6980v8.pdf), dividing by `1 - ema_decay^t`, where
+  `t` is the step count.
+
   Args:
     ema_decay: decay rate for the ExponentialMovingAverage.
-    name: name to prepend ops with.
+    name: name for variable scope of the ExponentialMovingAverage.
 
   Returns:
     Callable baseline function that takes the `DistributionTensor` (unused) and
@@ -177,14 +183,28 @@ def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
   """
 
   def mean_baseline(_, loss):
-    with ops.name_scope(name):
-      ema = training.ExponentialMovingAverage(decay=ema_decay)
+    with vs.variable_scope(name, default_name="MeanBaseline"):
       reduced_loss = math_ops.reduce_mean(loss)
+
+      ema = training.ExponentialMovingAverage(decay=ema_decay)
       update_op = ema.apply([reduced_loss])
+
+      # The bias correction term requires keeping track of how many times the
+      # EMA has been updated. Creating a variable here to do so. The global step
+      # is not used because it may or may not track exactly the number of times
+      # the EMA is updated.
+      ema_var = ema.average(reduced_loss)
+      assert ema_var is not None
+      with ops.colocate_with(ema_var):
+        num_updates = vs.get_variable(
+            "local_ema_step", initializer=0, trainable=False)
+      num_updates = num_updates.assign_add(1)
+      bias_correction = 1. - math_ops.pow(ema_decay, math_ops.cast(
+          num_updates, reduced_loss.dtype))
+
       with ops.control_dependencies([update_op]):
-        # TODO(rsepassi): Possibly implement the initialization bias correction
-        # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
-        baseline = ema.average(reduced_loss)
+        baseline = ema.average(reduced_loss) / bias_correction
+
       return baseline
 
   return mean_baseline
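
For context, a minimal usage sketch mirroring the updated test. The BernoulliTensor and loss_fn calls match the test code above; the import paths are assumptions about the contrib layout of this era (the diff only shows the aliases `st` and `sge`), and the constants are made up:

import tensorflow as tf
# Assumed module paths; adjust to wherever BernoulliTensor lives in your tree.
from tensorflow.contrib.bayesflow.python.ops import stochastic_graph as st
from tensorflow.contrib.bayesflow.python.ops import (
    stochastic_gradient_estimators as sge)

p = tf.constant(0.7)           # hypothetical Bernoulli parameter
final_loss = tf.constant(2.0)  # hypothetical downstream loss

# Score-function (REINFORCE) surrogate loss whose baseline is the
# bias-corrected EMA of the loss introduced by this change.
x = st.BernoulliTensor(
    p=p,
    loss_fn=sge.get_score_function_with_baseline(
        sge.get_mean_baseline(ema_decay=0.8)))
sf = x.loss(final_loss)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(sf)  # each run updates the EMA and its "local_ema_step" counter

Because get_mean_baseline now opens a variable scope with default_name="MeanBaseline" (uniquified per call), two such tensors can coexist in one graph without variable collisions, which is what testScoreFunctionWithMeanBaselineHasUniqueVarScope exercises.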