diff options
author | 2016-09-06 09:41:47 -0800 | |
---|---|---|
committer | 2016-09-06 10:48:04 -0700 | |
commit | a4f6a51a84a245d7109a1a256d10785bd8829843 (patch) | |
tree | e8b9f91041755f5adeb143340b1ffa8af95e30a0 /tensorflow | |
parent | 8c71e3c554a05f689e887f809c1bbb8d168d47c9 (diff) |
Add tests for stochastic gradient estimators and fix EMA usage bug
Change: 132336601
Diffstat (limited to 'tensorflow')
3 files changed, 108 insertions, 4 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 37ce250843..84bae2f768 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -90,6 +90,18 @@ cuda_py_test( ) cuda_py_test( + name = "stochastic_gradient_estimators_test", + size = "medium", + srcs = ["python/kernel_tests/stochastic_gradient_estimators_test.py"], + additional_deps = [ + ":bayesflow_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( name = "reinforce_simple_example", size = "small", srcs = ["examples/reinforce_simple/reinforce_simple_example.py"], diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py new file mode 100644 index 0000000000..56936e6c38 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py @@ -0,0 +1,92 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stochastic graphs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +st = tf.contrib.bayesflow.stochastic_tensor +sge = tf.contrib.bayesflow.stochastic_gradient_estimators + + +class StochasticGradientEstimatorsTest(tf.test.TestCase): + + def setUp(self): + self._p = tf.constant(0.999999) + self._final_loss = tf.constant(3.2) + + def _testScoreFunction(self, loss_fn, expected): + x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn) + sf = x.loss(self._final_loss) + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + self.assertAllClose(*sess.run([expected, sf])) + + def testScoreFunction(self): + expected = tf.log(self._p) * self._final_loss + self._testScoreFunction(sge.score_function, expected) + + def testScoreFunctionWithConstantBaseline(self): + b = tf.constant(9.8) + expected = tf.log(self._p) * (self._final_loss - b) + self._testScoreFunction( + sge.get_score_function_with_constant_baseline(b), expected) + + def testScoreFunctionWithBaselineFn(self): + b = tf.constant(9.8) + + def baseline_fn(stoch_tensor, loss): + self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) + self.assertTrue(isinstance(loss, tf.Tensor)) + return b + + expected = tf.log(self._p) * (self._final_loss - b) + self._testScoreFunction( + sge.get_score_function_with_baseline(baseline_fn), expected) + + def testScoreFunctionWithMeanBaseline(self): + ema_decay = 0.8 + x = st.BernoulliTensor( + p=self._p, + loss_fn=sge.get_score_function_with_baseline( + sge.get_mean_baseline(ema_decay))) + sf = x.loss(self._final_loss) + + expected = tf.log(self._p) * (self._final_loss - + (1. - ema_decay) * self._final_loss) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + sess.run(sf) # run to update EMA + self.assertAllClose(*sess.run([expected, sf])) + + def testScoreFunctionWithAdvantageFn(self): + b = tf.constant(9.8) + + def advantage_fn(stoch_tensor, loss): + self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) + self.assertTrue(isinstance(loss, tf.Tensor)) + return loss - b + + expected = tf.log(self._p) * (self._final_loss - b) + self._testScoreFunction( + sge.get_score_function_with_advantage(advantage_fn), expected) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py index 62b8446131..7cb8ef06f9 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py +++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py @@ -58,7 +58,6 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import training from tensorflow.python.util.all_util import make_all @@ -180,11 +179,12 @@ def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"): def mean_baseline(_, loss): with ops.name_scope(name): ema = training.ExponentialMovingAverage(decay=ema_decay) - update_op = ema.apply(math_ops.reduce_mean(loss)) - with control_flow_ops.control_dependencies([update_op]): + reduced_loss = math_ops.reduce_mean(loss) + update_op = ema.apply([reduced_loss]) + with ops.control_dependencies([update_op]): # TODO(rsepassi): Possibly implement the initialization bias correction # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf). - baseline = ema.average(loss) + baseline = ema.average(reduced_loss) return baseline return mean_baseline |