author:    2016-11-21 17:31:54 -0800
committer: 2016-11-21 17:45:08 -0800
commit:    ee983d8ae763ca9dbce2379f6899554509b53d91
tree:      c65e030384ba26107ae399b0c13a5b9dcf6baea0 /tensorflow/contrib/bayesflow
parent:    fff98d1db9648269e3dda1958dcfb3d76120dfd3
Add VIMCO advantage function to bayesflow.
Change: 139853413
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py |  76
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py               | 120
2 files changed, 196 insertions(+), 0 deletions(-)
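
For reference, the per-sample advantage this change computes (Eq. 9 of the VIMCO paper, as implemented by the numpy reference `_vimco` added in the test file below) can be written as follows, where f_i is the loss of the i-th of n samples:

    A_j = \log \frac{1}{n} \sum_{i=1}^{n} f_i
          \;-\; \log \frac{1}{n} \Big( \sum_{i \neq j} f_i + \Big( \prod_{i \neq j} f_i \Big)^{1/(n-1)} \Big)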
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
index ff8f70bea4..c6497db9ed 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 
 st = tf.contrib.bayesflow.stochastic_tensor
@@ -25,6 +26,31 @@ sge = tf.contrib.bayesflow.stochastic_gradient_estimators
 dists = tf.contrib.distributions
 
 
+def _vimco(loss):
+  """Python implementation of VIMCO."""
+  n = loss.shape[0]
+  log_loss = np.log(loss)
+  geometric_mean = []
+  for j in range(n):
+    geometric_mean.append(
+        np.exp(np.mean([log_loss[i, :] for i in range(n) if i != j], 0)))
+  geometric_mean = np.array(geometric_mean)
+
+  learning_signal = []
+  for j in range(n):
+    learning_signal.append(
+        np.sum([loss[i, :] for i in range(n) if i != j], 0))
+  learning_signal = np.array(learning_signal)
+
+  local_learning_signal = np.log(1/n * (learning_signal + geometric_mean))
+
+  # log_mean - local_learning_signal
+  log_mean = np.log(np.mean(loss, 0))
+  advantage = log_mean - local_learning_signal
+
+  return advantage
+
+
 class StochasticGradientEstimatorsTest(tf.test.TestCase):
 
   def setUp(self):
@@ -97,6 +123,56 @@ class StochasticGradientEstimatorsTest(tf.test.TestCase):
     self._testScoreFunction(
         sge.get_score_function_with_advantage(advantage_fn), expected)
 
+  def testVIMCOAdvantageFn(self):
+    # simple_loss: (3, 2) with 3 samples, batch size 2
+    simple_loss = np.array(
+        [[1.0, 1.5],
+         [1e-6, 1e4],
+         [2.0, 3.0]])
+    # random_loss: (100, 50, 64) with 100 samples, batch shape (50, 64)
+    random_loss = 100*np.random.rand(100, 50, 64)
+
+    advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False)
+
+    with self.test_session() as sess:
+      for loss in [simple_loss, random_loss]:
+        expected = _vimco(loss)
+        loss_t = tf.constant(loss, dtype=tf.float32)
+        advantage_t = advantage_fn(None, loss_t)  # ST is not used
+        advantage = sess.run(advantage_t)
+        self.assertEqual(expected.shape, advantage_t.get_shape())
+        self.assertAllClose(expected, advantage, atol=5e-5)
+
+  def testVIMCOAdvantageGradients(self):
+    loss = np.log(
+        [[1.0, 1.5],
+         [1e-6, 1e4],
+         [2.0, 3.0]])
+    advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
+
+    with self.test_session():
+      loss_t = tf.constant(loss, dtype=tf.float64)
+      advantage_t = advantage_fn(None, loss_t)  # ST is not used
+      gradient_error = tf.test.compute_gradient_error(
+          loss_t, loss_t.get_shape().as_list(),
+          advantage_t, advantage_t.get_shape().as_list(),
+          x_init_value=loss)
+      self.assertLess(gradient_error, 1e-3)
+
+  def testVIMCOAdvantageWithSmallProbabilities(self):
+    theta_value = np.random.rand(10, 100000)
+    # Test with float16 dtype to ensure stability even in this extreme case.
+    theta = tf.constant(theta_value, dtype=tf.float16)
+    advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True)
+
+    with self.test_session() as sess:
+      log_loss = -tf.reduce_sum(theta, [1])
+      advantage_t = advantage_fn(None, log_loss)
+      grad_t = tf.gradients(advantage_t, theta)[0]
+      advantage, grad = sess.run((advantage_t, grad_t))
+      self.assertTrue(np.all(np.isfinite(advantage)))
+      self.assertTrue(np.all(np.isfinite(grad)))
+
   def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self):
     ema_decay = 0.8
     x = st.StochasticTensor(
diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
index 2139419289..64488ebb10 100644
--- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
+++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py
@@ -56,6 +56,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -194,4 +196,122 @@ def get_mean_baseline(ema_decay=0.99, name=None):
   return mean_baseline
 
 
+def get_vimco_advantage_fn(have_log_loss=False):
+  """VIMCO (Variational Inference for Monte Carlo Objectives) baseline.
+
+  Implements the VIMCO baseline from the paper of the same name:
+
+  https://arxiv.org/pdf/1602.06725v2.pdf
+
+  Given a `loss` tensor (containing non-negative probabilities or ratios),
+  calculates the VIMCO advantage via Eq. 9 of the above paper.
+
+  The tensor `loss` should be shaped `[n, ...]`, with rank at least 1. Here,
+  the first axis is considered the single sampling dimension and `n` must
+  be at least 2. Specifically, the `StochasticTensor` is assumed to have
+  used the `SampleValue(n)` value type with `n > 1`.
+
+  Args:
+    have_log_loss: Python `Boolean`. If `True`, the loss is assumed to be the
+      log loss. If `False` (the default), it is assumed to be a nonnegative
+      probability or probability ratio.
+
+  Returns:
+    Callable baseline function that takes the `StochasticTensor` (unused) and
+    the downstream `loss`, and returns the VIMCO baseline for the loss.
+  """
+  def vimco_advantage_fn(_, loss, name=None):
+    """Internal VIMCO function.
+
+    Args:
+      _: ignored `StochasticTensor`.
+      loss: The loss `Tensor`.
+      name: Python string, the name scope to use.
+
+    Returns:
+      The advantage `Tensor`.
+    """
+    with ops.name_scope(name, "VIMCOAdvantage", values=[loss]):
+      loss = ops.convert_to_tensor(loss)
+      loss_shape = loss.get_shape()
+      loss_num_elements = loss_shape[0].value
+      n = math_ops.cast(
+          loss_num_elements or array_ops.shape(loss)[0], dtype=loss.dtype)
+
+      if have_log_loss:
+        log_loss = loss
+      else:
+        log_loss = math_ops.log(loss)
+
+      # Calculate L_hat, Eq. (4) -- stably
+      log_mean = math_ops.reduce_logsumexp(log_loss, [0]) - math_ops.log(n)
+
+      # expand_dims: Expand shape [a, b, c] to [a, 1, b, c]
+      log_loss_expanded = array_ops.expand_dims(log_loss, [1])
+
+      # divide: log_loss_sub with shape [a, a, b, c], where
+      #
+      #  log_loss_sub[i] = log_loss - log_loss[i]
+      #
+      #  = [ log_loss[j] - log_loss[i] for rows j = 0 ... i - 1     ]
+      #    [ zeros                                                  ]
+      #    [ log_loss[j] - log_loss[i] for rows j = i + 1 ... a - 1 ]
+      #
+      log_loss_sub = log_loss - log_loss_expanded
+
+      # reduce_sum: Sums each row across all the sub[i]'s; result is:
+      #   reduce_sum[j] = (n - 1) * log_loss[j] - (sum_{i != j} log_loss[i])
+      # divide by (n - 1) to get:
+      #   geometric_reduction[j] =
+      #     log_loss[j] - (sum_{i != j} log_loss[i]) / (n - 1)
+      geometric_reduction = math_ops.reduce_sum(log_loss_sub, [0]) / (n - 1)
+
+      # subtract this from the original log_loss to get the baseline:
+      #   log_geometric_mean[j] = (sum_{i != j} log_loss[i]) / (n - 1)
+      log_geometric_mean = log_loss - geometric_reduction
+
+      ## Equation (9)
+
+      # Calculate sum_{i != j} loss[i] -- via exp(reduce_logsumexp(.))
+      # reduce_logsumexp: log-sum-exp each row across all the
+      # -sub[i]'s, result is:
+      #
+      #  exp(reduce_logsumexp[j]) =
+      #    1 + sum_{i != j} exp(log_loss[i] - log_loss[j])
+      log_local_learning_reduction = math_ops.reduce_logsumexp(
+          -log_loss_sub, [0])
+
+      # convert local_learning_reduction to the sum-exp of the log-sum-exp:
+      #   (local_learning_reduction[j] - 1) * exp(log_loss[j])
+      #     = sum_{i != j} exp(log_loss[i])
+      local_learning_log_sum = (
+          _logexpm1(log_local_learning_reduction) + log_loss)
+
+      # Add (logaddexp) the local learning signals (Eq. 9)
+      local_learning_signal = (
+          math_ops.reduce_logsumexp(
+              array_ops.stack((local_learning_log_sum, log_geometric_mean)),
+              [0])
+          - math_ops.log(n))
+
+      advantage = log_mean - local_learning_signal
+
+      return advantage
+
+  return vimco_advantage_fn
+
+
+def _logexpm1(x):
+  """Stably calculate log(exp(x)-1)."""
+  with ops.name_scope("logsumexp1"):
+    eps = np.finfo(x.dtype.as_numpy_dtype).eps
+    # Choose a small offset that makes gradient calculations stable for
+    # float16, float32, and float64.
+    safe_log = lambda y: math_ops.log(y + eps / 1e8)  # For gradient stability
+    return array_ops.where(
+        math_ops.abs(x) < eps,
+        safe_log(x) + x/2 + x*x/24,  # small x approximation to log(expm1(x))
+        safe_log(math_ops.exp(x) - 1))
+
+
 __all__ = make_all(__name__)
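
For readers of this change, a minimal usage sketch follows. It only uses the API added above (sge.get_vimco_advantage_fn, advantage_fn(None, loss), and the existing sge.get_score_function_with_advantage) under TensorFlow 1.x contrib; the loss values are arbitrary illustration data, not taken from the tests.

    # Minimal sketch: computing the VIMCO advantage for a batch of sampled losses.
    # Assumes TF 1.x with tf.contrib.bayesflow, as in this change; data is made up.
    import numpy as np
    import tensorflow as tf

    sge = tf.contrib.bayesflow.stochastic_gradient_estimators

    # n = 4 Monte Carlo samples along axis 0, batch size 3; entries are
    # nonnegative probability ratios, so have_log_loss=False.
    loss = tf.constant(np.random.rand(4, 3) + 0.1, dtype=tf.float32)

    advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False)
    advantage = advantage_fn(None, loss)  # the StochasticTensor argument is unused

    # The advantage function can also feed the score-function estimator,
    # mirroring the existing tests:
    score_fn = sge.get_score_function_with_advantage(advantage_fn)

    with tf.Session() as sess:
      print(sess.run(advantage))  # shape (4, 3): one advantage per sample and batch element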