path: root/tensorflow/contrib/bayesflow
author Eugene Brevdo <ebrevdo@google.com> 2016-11-21 17:31:54 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-11-21 17:45:08 -0800
commit ee983d8ae763ca9dbce2379f6899554509b53d91 (patch)
tree c65e030384ba26107ae399b0c13a5b9dcf6baea0 /tensorflow/contrib/bayesflow
parent fff98d1db9648269e3dda1958dcfb3d76120dfd3 (diff)
Add VIMCO advantage function to bayesflow.
Change: 139853413
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py | 76
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py | 120
2 files changed, 196 insertions, 0 deletions
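
For orientation, a minimal usage sketch of the new function (not part of the change itself; the tensor shape and the have_log_loss choice are illustrative assumptions, and the API names are the ones exercised by the tests below):

import tensorflow as tf

sge = tf.contrib.bayesflow.stochastic_gradient_estimators

# Hypothetical downstream log-loss: n = 4 samples along axis 0, batch size 8.
# The first axis must be the sample dimension, with n >= 2.
log_loss = tf.random_normal([4, 8])

# Build the VIMCO advantage function; have_log_loss=True means the loss
# passed in is already in log space.
advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)

# As in the tests, the StochasticTensor argument is unused; the result has
# the same shape as the loss.
advantage = advantage_fn(None, log_loss)

# The same callable can be handed to the existing score-function estimator.
score_fn = sge.get_score_function_with_advantage(advantage_fn)
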
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
index ff8f70bea4..c6497db9ed 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
import tensorflow as tf
st = tf.contrib.bayesflow.stochastic_tensor
@@ -25,6 +26,31 @@ sge = tf.contrib.bayesflow.stochastic_gradient_estimators
dists = tf.contrib.distributions
+def _vimco(loss):
+ """Python implementation of VIMCO."""
+ n = loss.shape[0]
+ log_loss = np.log(loss)
+ geometric_mean = []
+ for j in range(n):
+ geometric_mean.append(
+ np.exp(np.mean([log_loss[i, :] for i in range(n) if i != j], 0)))
+ geometric_mean = np.array(geometric_mean)
+
+ learning_signal = []
+ for j in range(n):
+ learning_signal.append(
+ np.sum([loss[i, :] for i in range(n) if i != j], 0))
+ learning_signal = np.array(learning_signal)
+
+ local_learning_signal = np.log(1/n * (learning_signal + geometric_mean))
+
+ # log_mean - local_learning_signal
+ log_mean = np.log(np.mean(loss, 0))
+ advantage = log_mean - local_learning_signal
+
+ return advantage
+
+
class StochasticGradientEstimatorsTest(tf.test.TestCase):
def setUp(self):
@@ -97,6 +123,56 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
self._testScoreFunction(
sge.get_score_function_with_advantage(advantage_fn), expected)
+ def testVIMCOAdvantageFn(self):
+ # simple_loss: (3, 2) with 3 samples, batch size 2
+ simple_loss = np.array(
+ [[1.0, 1.5],
+ [1e-6, 1e4],
+ [2.0, 3.0]])
+ # random_loss: (100, 50, 64) with 100 samples, batch shape (50, 64)
+ random_loss = 100*np.random.rand(100, 50, 64)
+
+ advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False)
+
+ with self.test_session() as sess:
+ for loss in [simple_loss, random_loss]:
+ expected = _vimco(loss)
+ loss_t = tf.constant(loss, dtype=tf.float32)
+ advantage_t = advantage_fn(None, loss_t) # ST is not used
+ advantage = sess.run(advantage_t)
+ self.assertEqual(expected.shape, advantage_t.get_shape())
+ self.assertAllClose(expected, advantage, atol=5e-5)
+
+ def testVIMCOAdvantageGradients(self):
+ loss = np.log(
+ [[1.0, 1.5],
+ [1e-6, 1e4],
+ [2.0, 3.0]])
+ advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
+
+ with self.test_session():
+ loss_t = tf.constant(loss, dtype=tf.float64)
+ advantage_t = advantage_fn(None, loss_t) # ST is not used
+ gradient_error = tf.test.compute_gradient_error(
+ loss_t, loss_t.get_shape().as_list(),
+ advantage_t, advantage_t.get_shape().as_list(),
+ x_init_value=loss)
+ self.assertLess(gradient_error, 1e-3)
+
+ def testVIMCOAdvantageWithSmallProbabilities(self):
+ theta_value = np.random.rand(10, 100000)
+ # Test with float16 dtype to ensure stability even in this extreme case.
+ theta = tf.constant(theta_value, dtype=tf.float16)
+ advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
+
+ with self.test_session() as sess:
+ log_loss = -tf.reduce_sum(theta, [1])
+ advantage_t = advantage_fn(None, log_loss)
+ grad_t = tf.gradients(advantage_t, theta)[0]
+ advantage, grad = sess.run((advantage_t, grad_t))
+ self.assertTrue(np.all(np.isfinite(advantage)))
+ self.assertTrue(np.all(np.isfinite(grad)))
+
def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
ema_decay = 0.8
x = st.StochasticTensor(
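
For reference, the quantity the pure-NumPy `_vimco` helper above computes (elementwise over the batch dimensions, for each sample index j of the n samples) can be written as

  A_j = \log\Big(\frac{1}{n}\sum_{i} f_i\Big) - \log\Big(\frac{1}{n}\Big(\sum_{i \ne j} f_i + \Big(\prod_{i \ne j} f_i\Big)^{1/(n-1)}\Big)\Big),

with f_i = loss[i]: the log of the Monte Carlo mean minus the log of the same mean with the j-th term replaced by the geometric mean of the remaining samples. This is the advantage of Eq. 9 in the VIMCO paper referenced by the implementation below, which reproduces it entirely in log space for numerical stability.
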
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index 2139419289..64488ebb10 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -56,6 +56,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -194,4 +196,122 @@ def get_mean_baseline(ema_decay=0.99, name=None):
return mean_baseline
+def get_vimco_advantage_fn(have_log_loss=False):
+ """VIMCO (Variational Inference for Monte Carlo Objectives) baseline.
+
+ Implements the VIMCO baseline from the paper of the same name:
+
+ https://arxiv.org/pdf/1602.06725v2.pdf
+
+ Given a `loss` tensor (containing non-negative probabilities or ratios),
+ calculates the VIMCO advantage via Eq. 9 of the above paper.
+
+ The tensor `loss` should be shaped `[n, ...]`, with rank at least 1. Here,
+ the first axis is considered the single sampling dimension and `n` must
+ be at least 2. Specifically, the `StochasticTensor` is assumed to have
+ used the `SampleValue(n)` value type with `n > 1`.
+
+ Args:
+ have_log_loss: Python `Boolean`. If `True`, the loss is assumed to be the
+ log loss. If `False` (the default), it is assumed to be a nonnegative
+ probability or probability ratio.
+
+ Returns:
+ Callable advantage function that takes the `StochasticTensor` (unused) and
+ the downstream `loss`, and returns the VIMCO advantage for the loss.
+ """
+ def vimco_advantage_fn(_, loss, name=None):
+ """Internal VIMCO function.
+
+ Args:
+ _: ignored `StochasticTensor`.
+ loss: The loss `Tensor`.
+ name: Python string, the name scope to use.
+
+ Returns:
+ The advantage `Tensor`.
+ """
+ with ops.name_scope(name, "VIMCOAdvantage", values=[loss]):
+ loss = ops.convert_to_tensor(loss)
+ loss_shape = loss.get_shape()
+ loss_num_elements = loss_shape[0].value
+ n = math_ops.cast(
+ loss_num_elements or array_ops.shape(loss)[0], dtype=loss.dtype)
+
+ if have_log_loss:
+ log_loss = loss
+ else:
+ log_loss = math_ops.log(loss)
+
+ # Calculate L_hat, Eq. (4) -- stably
+ log_mean = math_ops.reduce_logsumexp(log_loss, [0]) - math_ops.log(n)
+
+ # expand_dims: Expand shape [a, b, c] to [a, 1, b, c]
+ log_loss_expanded = array_ops.expand_dims(log_loss, [1])
+
+ # divide: log_loss_sub with shape [a, a, b, c], where
+ #
+ # log_loss_sub[i] = log_loss - log_loss[i]
+ #
+ # = [ log_loss[j] - log_loss[i] for rows j = 0 ... i - 1 ]
+ # [ zeros ]
+ # [ log_loss[j] - log_loss[i] for rows j = i + 1 ... a - 1 ]
+ #
+ log_loss_sub = log_loss - log_loss_expanded
+
+ # reduce_sum: Sums each row across all the sub[i]'s; result is:
+ # reduce_sum[j] = (n - 1) * log_loss[j] - (sum_{i != j} log_loss[i])
+ # divide by (n - 1) to get:
+ # geometric_reduction[j] =
+ # log_loss[j] - (sum_{i != j} log_loss[i]) / (n - 1)
+ geometric_reduction = math_ops.reduce_sum(log_loss_sub, [0]) / (n - 1)
+
+ # subtract this from the original log_loss to get the log geometric mean:
+ # geometric_mean[j] = exp((sum_{i != j} log_loss[i]) / (n - 1))
+ log_geometric_mean = log_loss - geometric_reduction
+
+ ## Equation (9)
+
+ # Calculate sum_{i != j} loss[i] -- via exp(reduce_logsumexp(.))
+ # reduce_logsumexp: log-sum-exp each row across all the
+ # -sub[i]'s, result is:
+ #
+ # exp(reduce_logsumexp[j]) =
+ # 1 + sum_{i != j} exp(log_loss[i] - log_loss[j])
+ log_local_learning_reduction = math_ops.reduce_logsumexp(
+ -log_loss_sub, [0])
+
+ # Convert to log(sum_{i != j} loss[i]) via _logexpm1, using
+ # (exp(log_local_learning_reduction[j]) - 1) * exp(log_loss[j])
+ # = sum_{i != j} exp(log_loss[i])
+ local_learning_log_sum = (
+ _logexpm1(log_local_learning_reduction) + log_loss)
+
+ # Add (logaddexp) the held-out sum and the geometric mean, divide by n (Eq. 9)
+ local_learning_signal = (
+ math_ops.reduce_logsumexp(
+ array_ops.stack((local_learning_log_sum, log_geometric_mean)),
+ [0])
+ - math_ops.log(n))
+
+ advantage = log_mean - local_learning_signal
+
+ return advantage
+
+ return vimco_advantage_fn
+
+
+def _logexpm1(x):
+ """Stably calculate log(exp(x)-1)."""
+ with ops.name_scope("logexpm1"):
+ eps = np.finfo(x.dtype.as_numpy_dtype).eps
+ # Choose a small offset that makes gradient calculations stable for
+ # float16, float32, and float64.
+ safe_log = lambda y: math_ops.log(y + eps / 1e8) # For gradient stability
+ return array_ops.where(
+ math_ops.abs(x) < eps,
+ safe_log(x) + x/2 + x*x/24, # small x approximation to log(expm1(x))
+ safe_log(math_ops.exp(x) - 1))
+
+
__all__ = make_all(__name__)
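
A side note on `_logexpm1` above: for |x| below machine epsilon, exp(x) rounds to 1 and the naive log(exp(x) - 1) returns -inf, which is what the small-x series branch avoids. A standalone NumPy check of that behavior (illustrative only, not part of the change):

import numpy as np

x = np.float32(1e-8)  # well below float32 eps (~1.19e-7)
with np.errstate(divide="ignore"):
  naive = np.log(np.exp(x) - np.float32(1.0))  # exp(x) rounds to 1.0 -> -inf
series = np.log(x) + x / 2 + x * x / 24        # expansion used in _logexpm1
print(naive, series)                           # -inf vs. roughly -18.42
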