about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2016-09-06 09:41:47 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-09-06 10:48:04 -0700
commit a4f6a51a84a245d7109a1a256d10785bd8829843 (patch)
tree e8b9f91041755f5adeb143340b1ffa8af95e30a0 /tensorflow
parent 8c71e3c554a05f689e887f809c1bbb8d168d47c9 (diff)
Add tests for stochastic gradient estimators and fix EMA usage bug
Change: 132336601
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/contrib/bayesflow/BUILD | 12
-rw-r--r-- tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py | 92
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py | 8
3 files changed, 108 insertions, 4 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 37ce250843..84bae2f768 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -90,6 +90,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "stochastic_gradient_estimators_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/stochastic_gradient_estimators_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "reinforce_simple_example",
size = "small",
srcs = ["examples/reinforce_simple/reinforce_simple_example.py"],
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
new file mode 100644
index 0000000000..56936e6c38
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
@@ -0,0 +1,92 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for stochastic gradient estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+st = tf.contrib.bayesflow.stochastic_tensor
+sge = tf.contrib.bayesflow.stochastic_gradient_estimators
+
+
+class StochasticGradientEstimatorsTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._p = tf.constant(0.999999)
+ self._final_loss = tf.constant(3.2)
+
+ def _testScoreFunction(self, loss_fn, expected):
+ x = st.BernoulliTensor(p=self._p, loss_fn=loss_fn)
+ sf = x.loss(self._final_loss)
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ self.assertAllClose(*sess.run([expected, sf]))
+
+ def testScoreFunction(self):
+ expected = tf.log(self._p) * self._final_loss
+ self._testScoreFunction(sge.score_function, expected)
+
+ def testScoreFunctionWithConstantBaseline(self):
+ b = tf.constant(9.8)
+ expected = tf.log(self._p) * (self._final_loss - b)
+ self._testScoreFunction(
+ sge.get_score_function_with_constant_baseline(b), expected)
+
+ def testScoreFunctionWithBaselineFn(self):
+ b = tf.constant(9.8)
+
+ def baseline_fn(stoch_tensor, loss):
+ self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor))
+ self.assertTrue(isinstance(loss, tf.Tensor))
+ return b
+
+ expected = tf.log(self._p) * (self._final_loss - b)
+ self._testScoreFunction(
+ sge.get_score_function_with_baseline(baseline_fn), expected)
+
+ def testScoreFunctionWithMeanBaseline(self):
+ ema_decay = 0.8
+ x = st.BernoulliTensor(
+ p=self._p,
+ loss_fn=sge.get_score_function_with_baseline(
+ sge.get_mean_baseline(ema_decay)))
+ sf = x.loss(self._final_loss)
+
+ expected = tf.log(self._p) * (self._final_loss -
+ (1. - ema_decay) * self._final_loss)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ sess.run(sf) # run to update EMA
+ self.assertAllClose(*sess.run([expected, sf]))
+
+ def testScoreFunctionWithAdvantageFn(self):
+ b = tf.constant(9.8)
+
+ def advantage_fn(stoch_tensor, loss):
+ self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor))
+ self.assertTrue(isinstance(loss, tf.Tensor))
+ return loss - b
+
+ expected = tf.log(self._p) * (self._final_loss - b)
+ self._testScoreFunction(
+ sge.get_score_function_with_advantage(advantage_fn), expected)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index 62b8446131..7cb8ef06f9 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -58,7 +58,6 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training
from tensorflow.python.util.all_util import make_all
@@ -180,11 +179,12 @@ def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
def mean_baseline(_, loss):
with ops.name_scope(name):
ema = training.ExponentialMovingAverage(decay=ema_decay)
- update_op = ema.apply(math_ops.reduce_mean(loss))
- with control_flow_ops.control_dependencies([update_op]):
+ reduced_loss = math_ops.reduce_mean(loss)
+ update_op = ema.apply([reduced_loss])
+ with ops.control_dependencies([update_op]):
# TODO(rsepassi): Possibly implement the initialization bias correction
# term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
- baseline = ema.average(loss)
+ baseline = ema.average(reduced_loss)
return baseline
return mean_baseline