author    Joshua V. Dillon <jvdillon@google.com>  2017-07-21 17:51:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-07-21 17:55:10 -0700
commit    9c08e99a02c682acca0dcb9be9f79c3a9d38b615 (patch)
tree      1e3c55bac1f1ffc6b7ce480ecbe795c9cd7304ba
parent    4895cec77fac920a77a79e9e6f79951503483303 (diff)
Implement VIMCO-like objective for approx Csiszar f-Divergence. Simplify monte_carlo.expectation_v2 calculation when doing score-trick.
PiperOrigin-RevId: 162806278
 tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py | 208
 tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py        |  90
 tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py               |   1
 tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py          | 202
 tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py                 | 107
 5 files changed, 544 insertions(+), 64 deletions(-)
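
For orientation before diving into the hunks: a minimal sketch (not part of the commit) of how the two entry points touched here are called after this change. The module paths are taken from the test imports in the diff below; the toy Normal `p` and `q` are placeholders.

```python
# Sketch only: paths and toy distributions are assumptions, not committed code.
from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl as cd
from tensorflow.python.ops.distributions import normal as normal_lib

p = normal_lib.Normal(loc=0., scale=1.)  # stands in for the joint distribution
q = normal_lib.Normal(loc=1., scale=2.)  # approximate posterior

# `p` is now passed as a log-prob callable (`p_log_prob`) rather than a
# Distribution instance.
approx_kl = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,
    p_log_prob=p.log_prob,
    q=q,
    num_draws=int(1e5))

# New VIMCO-like objective: reduces the gradient variance of
# f(log-average importance weight); intended for the score-trick regime.
vimco = cd.csiszar_vimco(
    f=cd.kl_reverse,
    p_log_prob=p.log_prob,
    q=q,
    num_draws=10,
    num_batch_draws=4)
```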
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index dda3f60065..d06b69885c 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -37,6 +37,14 @@ from tensorflow.python.platform import test
cd = csiszar_divergence_impl
+def tridiag(d, diag_value, offdiag_value):
+ """d x d matrix with given value on diag, and one super/sub diag."""
+ diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
+ three_bands = array_ops.matrix_band_part(
+ array_ops.fill([d, d], offdiag_value), 1, 1)
+ return diag_mat + three_bands
+
+
class AmariAlphaTest(test.TestCase):
def setUp(self):
@@ -483,14 +491,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_forward,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_forward(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
@@ -517,14 +525,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_reverse,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
@@ -540,33 +548,26 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
rtol=0.02, atol=0.)
- def _tridiag(self, d, diag_value, offdiag_value):
- """d x d matrix with given value on diag, and one super/sub diag."""
- diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
- three_bands = array_ops.matrix_band_part(
- array_ops.fill([d, d], offdiag_value), 1, 1)
- return diag_mat + three_bands
-
def test_kl_reverse_multidim(self):
with self.test_session() as sess:
d = 5 # Dimension
p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+ covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d)
approx_kl = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_reverse,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
@@ -588,7 +589,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
d = 5 # Dimension
p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+ covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
# Variance is very high when approximating Forward KL, so we make
# scale_diag larger than in test_kl_reverse_multidim. This ensures q
@@ -597,14 +598,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_forward,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_forward(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=int(1e5),
seed=1)
@@ -628,7 +629,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
seed = 1
p = mvn_full_lib.MultivariateNormalFullCovariance(
- covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+ covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
# Variance is very high when approximating Forward KL, so we make
# scale_diag larger than in test_kl_reverse_multidim. This ensures q
@@ -639,21 +640,21 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_reverse,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=num_draws,
seed=seed)
approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=num_draws,
seed=seed)
approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
f=cd.kl_reverse,
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=num_draws,
use_reparametrization=False,
@@ -662,7 +663,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl_self_normalized_score_trick = (
cd.monte_carlo_csiszar_f_divergence(
f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
- p=p,
+ p_log_prob=p.log_prob,
q=q,
num_draws=num_draws,
use_reparametrization=False,
@@ -670,7 +671,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
exact_kl = kullback_leibler.kl_divergence(q, p)
- grad = lambda fs: gradients_impl.gradients(fs, s)[0]
+ grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
[
approx_kl_grad_,
@@ -684,11 +685,11 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
approx_kl_self_normalized_score_trick_,
exact_kl_,
] = sess.run([
- grad(approx_kl),
- grad(approx_kl_self_normalized),
- grad(approx_kl_score_trick),
- grad(approx_kl_self_normalized_score_trick),
- grad(exact_kl),
+ grad_sum(approx_kl),
+ grad_sum(approx_kl_self_normalized),
+ grad_sum(approx_kl_score_trick),
+ grad_sum(approx_kl_self_normalized_score_trick),
+ grad_sum(exact_kl),
approx_kl,
approx_kl_self_normalized,
approx_kl_score_trick,
@@ -724,5 +725,154 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
rtol=0.017, atol=0.)
-if __name__ == '__main__':
+class CsiszarVIMCOTest(test.TestCase):
+
+ def _numpy_csiszar_vimco_helper(self, logu):
+ """Numpy implementation of `csiszar_vimco_helper`."""
+ n = logu.shape[0]
+ u = np.exp(logu)
+ loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu).
+ for j in range(n):
+ loogeoavg_u.append(np.exp(np.mean(
+ [logu[i, ...] for i in range(n) if i != j],
+ axis=0)))
+ loogeoavg_u = np.array(loogeoavg_u)
+
+ loosum_u = [] # Leave-one-out sum of exp(logu).
+ for j in range(n):
+ loosum_u.append(np.sum(
+ [u[i, ...] for i in range(n) if i != j],
+ axis=0))
+ loosum_u = np.array(loosum_u)
+
+ # Natural log of the average u except each is swapped-out for its
+ # leave-`i`-th-out Geometric average.
+ log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n)
+
+ log_avg_u = np.log(np.mean(u, axis=0))
+ return log_avg_u, log_sooavg_u
+
+ def test_vimco_helper(self):
+
+ with self.test_session() as sess:
+ logu = np.linspace(-20, 20, 100)
+ np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu)
+ [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu))
+ self.assertAllClose(np_log_avg_u, log_avg_u,
+ rtol=1e-2, atol=0.)
+ self.assertAllClose(np_log_sooavg_u, log_sooavg_u,
+ rtol=1e-2, atol=0.)
+
+ def test_vimco_helper_gradient(self):
+
+ with self.test_session():
+ logu = array_ops.constant(
+ np.linspace(-1e2, 100., 100).reshape([50, 2]))
+ log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu)
+ g = gradients_impl.gradients(log_avg_u - log_sooavg_u, logu)[0].eval()
+ self.assertAllEqual(np.ones_like(g, dtype=np.bool), np.isfinite(g))
+ self.assertAllEqual(np.ones_like(g, dtype=np.bool), g != 0.)
+
+ def test_vimco_and_gradient(self):
+
+ with self.test_session() as sess:
+ dims = 5 # Dimension
+ num_draws = int(20)
+ num_batch_draws = int(3)
+ seed = 1
+
+ f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
+ np_f = lambda logu: -logu
+
+ p = mvn_full_lib.MultivariateNormalFullCovariance(
+ covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))
+
+ # Variance is very high when approximating Forward KL, so we make
+ # scale_diag larger than in test_kl_reverse_multidim. This ensures q
+ # "covers" p and thus Var_q[p/q] is smaller.
+ s = array_ops.constant(1.)
+ q = mvn_diag_lib.MultivariateNormalDiag(
+ scale_diag=array_ops.tile([s], [dims]))
+
+ vimco = cd.csiszar_vimco(
+ f=f,
+ p_log_prob=p.log_prob,
+ q=q,
+ num_draws=num_draws,
+ num_batch_draws=num_batch_draws,
+ seed=seed)
+
+ x = q.sample(sample_shape=[num_draws, num_batch_draws],
+ seed=seed)
+ x = array_ops.stop_gradient(x)
+ logu = p.log_prob(x) - q.log_prob(x)
+ f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])
+
+ grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
+
+ def jacobian(x):
+ # Warning: this function is slow and may not even finish if prod(shape)
+ # is larger than, say, 100.
+ shape = x.shape.as_list()
+ assert all(s is not None for s in shape)
+ x = array_ops.reshape(x, shape=[-1])
+ r = [grad_sum(x[i]) for i in range(np.prod(shape))]
+ return array_ops.reshape(array_ops.stack(r), shape=shape)
+
+ [
+ logu_,
+ jacobian_logqx_,
+ vimco_,
+ grad_vimco_,
+ f_log_sum_u_,
+ grad_mean_f_log_sum_u_,
+ ] = sess.run([
+ logu,
+ jacobian(q.log_prob(x)),
+ vimco,
+ grad_sum(vimco),
+ f_log_sum_u,
+ grad_sum(f_log_sum_u) / num_batch_draws,
+ ])
+
+ np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu_)
+
+ # Test VIMCO loss is correct.
+ self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
+ rtol=1e-5, atol=0.)
+
+ # Test gradient of VIMCO loss is correct.
+ #
+ # To make this computation we'll inject two gradients from TF:
+ # - grad[mean(f(log(sum(p(x)/q(x)))))]
+ # - jacobian[log(q(x))].
+ #
+ # We now justify why using these (and only these) TF values for
+ # ground-truth does not undermine the completeness of this test.
+ #
+ # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
+ # correctness of the zero-th order derivative (for each batch member).
+ # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient
+ # information, we can safely rely on TF.
+ self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.)
+ #
+ # Regarding `jacobian_logqx_`, note that testing the gradient of
+ # `q.log_prob` is outside the scope of this unit-test thus we may safely
+ # use TF to find it.
+
+ # The `mean` is across batches and the `sum` is across iid samples.
+ np_grad_vimco = (
+ grad_mean_f_log_sum_u_
+ + np.mean(
+ np.sum(
+ jacobian_logqx_ * (np_f(np_log_avg_u)
+ - np_f(np_log_sooavg_u)),
+ axis=0),
+ axis=0))
+
+ self.assertAllClose(np_grad_vimco, grad_vimco_,
+ rtol=1e-5, atol=0.)
+
+
+if __name__ == "__main__":
test.main()
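
The `_numpy_csiszar_vimco_helper` reference implementation above defines the two averages the TF helper returns; a tiny standalone numpy sketch (not part of the commit) makes them concrete on a three-element `logu`.

```python
# Sketch only: illustrates log_avg_u and log_sooavg_u on a toy input.
import numpy as np

logu = np.array([0.0, 1.0, 2.0])  # log importance weights, n = 3
u = np.exp(logu)
n = logu.shape[0]

# Natural-log of the plain average of u.
log_avg_u = np.log(u.mean())

# Swap-one-out average: replace u[i] with the geometric mean of the others,
# i.e. (sum of the other u's + leave-i-out geometric average) / n.
log_sooavg_u = np.array([
    np.log((u.sum() - u[i] + np.exp(np.mean(np.delete(logu, i)))) / n)
    for i in range(n)])

print(log_avg_u, log_sooavg_u)
```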
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 38426ebf1f..d9e23646d8 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
@@ -210,6 +213,93 @@ class ExpectationTest(test.TestCase):
efx_score_grad_[2:-2],
rtol=0.05, atol=0.)
+ def test_docstring_example_normal(self):
+ with self.test_session() as sess:
+ num_draws = int(1e5)
+ mu_p = constant_op.constant(0.)
+ mu_q = constant_op.constant(1.)
+ p = normal_lib.Normal(loc=mu_p, scale=1.)
+ q = normal_lib.Normal(loc=mu_q, scale=2.)
+ exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
+ approx_kl_normal_normal = monte_carlo_lib.expectation(
+ f=lambda x: p.log_prob(x) - q.log_prob(x),
+ samples=p.sample(num_draws, seed=42),
+ log_prob=p.log_prob,
+ use_reparametrization=(p.reparameterization_type
+ == distribution_lib.FULLY_REPARAMETERIZED))
+ [exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
+ exact_kl_normal_normal, approx_kl_normal_normal])
+ self.assertEqual(
+ True,
+ p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
+ self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
+ rtol=0.01, atol=0.)
+
+ # Compare gradients. (Not present in `docstring`.)
+ gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
+ gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
+ [
+ gradp_exact_kl_normal_normal_,
+ gradq_exact_kl_normal_normal_,
+ gradp_approx_kl_normal_normal_,
+ gradq_approx_kl_normal_normal_,
+ ] = sess.run([
+ gradp(exact_kl_normal_normal),
+ gradq(exact_kl_normal_normal),
+ gradp(approx_kl_normal_normal),
+ gradq(approx_kl_normal_normal),
+ ])
+ self.assertAllClose(gradp_exact_kl_normal_normal_,
+ gradp_approx_kl_normal_normal_,
+ rtol=0.01, atol=0.)
+ self.assertAllClose(gradq_exact_kl_normal_normal_,
+ gradq_approx_kl_normal_normal_,
+ rtol=0.01, atol=0.)
+
+ def test_docstring_example_gamma(self):
+ with self.test_session() as sess:
+ num_draws = int(1e5)
+ concentration_p = constant_op.constant(1.)
+ concentration_q = constant_op.constant(2.)
+ p = gamma_lib.Gamma(concentration=concentration_p, rate=1.)
+ q = gamma_lib.Gamma(concentration=concentration_q, rate=3.)
+ approx_kl_gamma_gamma = monte_carlo_lib.expectation(
+ f=lambda x: p.log_prob(x) - q.log_prob(x),
+ samples=p.sample(num_draws, seed=42),
+ log_prob=p.log_prob,
+ use_reparametrization=(p.reparameterization_type
+ == distribution_lib.FULLY_REPARAMETERIZED))
+ exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q)
+ [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([
+ exact_kl_gamma_gamma, approx_kl_gamma_gamma])
+ self.assertEqual(
+ False,
+ p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
+ self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_,
+ rtol=0.01, atol=0.)
+
+ # Compare gradients. (Not present in `docstring`.)
+ gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0]
+ gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0]
+ [
+ gradp_exact_kl_gamma_gamma_,
+ gradq_exact_kl_gamma_gamma_,
+ gradp_approx_kl_gamma_gamma_,
+ gradq_approx_kl_gamma_gamma_,
+ ] = sess.run([
+ gradp(exact_kl_gamma_gamma),
+ gradq(exact_kl_gamma_gamma),
+ gradp(approx_kl_gamma_gamma),
+ gradq(approx_kl_gamma_gamma),
+ ])
+ # Notice that variance (i.e., `rtol`) is higher when using score-trick.
+ self.assertAllClose(gradp_exact_kl_gamma_gamma_,
+ gradp_approx_kl_gamma_gamma_,
+ rtol=0.05, atol=0.)
+ self.assertAllClose(gradq_exact_kl_gamma_gamma_,
+ gradq_approx_kl_gamma_gamma_,
+ rtol=0.03, atol=0.)
+
if __name__ == '__main__':
test.main()
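
The docstring added to `expectation` further down in this diff notes that KL estimates like the ones tested above are more directly expressed through the Csiszar f-divergence entry point. A hedged sketch of that pattern, reusing the Normal `p` and `q` from `test_docstring_example_normal`:

```python
# Sketch only: equivalent KL[p || q] estimate via the f-divergence API.
# Note the swap: KL[p || q] = E_p[log(p/q)], so we sample from `p`, which
# therefore occupies the `q=` slot of `monte_carlo_csiszar_f_divergence`.
from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl as cd

approx_kl_p_q = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,       # kl_reverse(logu) = -logu when not self-normalized
    p_log_prob=q.log_prob,
    q=p,
    num_draws=int(1e5))
```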
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
index b1bcc86022..9f7a95f138 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
@@ -31,6 +31,7 @@ _allowed_symbols = [
'amari_alpha',
'arithmetic_geometric',
'chi_square',
+ 'csiszar_vimco',
'dual_csiszar_function',
'jeffreys',
'jensen_shannon',
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 247d434bb7..54900ab893 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -17,6 +17,7 @@
@@amari_alpha
@@arithmetic_geometric
@@chi_square
+@@csiszar_vimco
@@dual_csiszar_function
@@jeffreys
@@jensen_shannon
@@ -46,6 +47,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
def amari_alpha(logu, alpha=1., self_normalized=False, name=None):
@@ -696,7 +698,7 @@ def dual_csiszar_function(logu, csiszar_function, name=None):
Args:
logu: `float`-like `Tensor` representing `log(u)` from above.
- csiszar_function: Python callable representing a Csiszar-function over
+ csiszar_function: Python `callable` representing a Csiszar-function over
log-domain.
name: Python `str` name prefixed to Ops created by this function.
@@ -765,7 +767,7 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
Args:
logu: `float`-like `Tensor` representing `log(u)` from above.
- csiszar_function: Python callable representing a Csiszar-function over
+ csiszar_function: Python `callable` representing a Csiszar-function over
log-domain.
name: Python `str` name prefixed to Ops created by this function.
@@ -781,7 +783,13 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
def monte_carlo_csiszar_f_divergence(
- f, p, q, num_draws, use_reparametrization=True, seed=None, name=None):
+ f,
+ p_log_prob,
+ q,
+ num_draws,
+ use_reparametrization=None,
+ seed=None,
+ name=None):
"""Monte-Carlo approximation of the Csiszar f-Divergence.
A Csiszar-function is a member of,
@@ -843,15 +851,22 @@ def monte_carlo_csiszar_f_divergence(
"Evidence Divergence Bound Optimization" (EDBO).
Args:
- f: Python callable representing a Csiszar-function in log-space.
- p: `tf.Distribution`-like instance; must implement `log_prob(x)`.
+ f: Python `callable` representing a Csiszar-function in log-space, i.e.,
+ takes `p_log_prob(q_samples) - q.log_prob(q_samples)`.
+ p_log_prob: Python `callable` taking (a batch of) samples from `q` and
+ returning the natural-log of the probability under distribution `p`.
+ (In variational inference `p` is the joint distribution.)
q: `tf.Distribution`-like instance; must implement:
- `reparameterization_type`, `sample(n)`, and `log_prob(x)`.
+ `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`.
+ (In variational inference `q` is the approximate posterior distribution.)
num_draws: Integer scalar number of draws used to approximate the
f-Divergence expectation.
- use_reparametrization: Python `bool`. When `True` uses the standard
- Monte-Carlo average. When `False` uses the score-gradient trick. (See
- above for details.)
+ use_reparametrization: Python `bool`. When `None` (the default),
+ automatically set to:
+ `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`.
+ When `True` uses the standard Monte-Carlo average. When `False` uses the
+ score-gradient trick. (See above for details.) When `False`, consider
+ using `csiszar_vimco`.
seed: Python `int` seed for `q.sample`.
name: Python `str` name prefixed to Ops created by this function.
@@ -866,18 +881,179 @@ def monte_carlo_csiszar_f_divergence(
samples of another distribution which does not depend on the
parameterization of `q`. This property ensures the gradient (with respect
to parameters) is valid.
+ TypeError: if `p_log_prob` is not a Python `callable`.
"""
with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
- if (use_reparametrization and
- q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
+ if use_reparametrization is None:
+ use_reparametrization = (q.reparameterization_type
+ == distribution.FULLY_REPARAMETERIZED)
+ elif (use_reparametrization and
+ q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
# TODO(jvdillon): Consider only raising an exception if the gradient is
# requested.
raise ValueError(
"Distribution `q` must be reparameterized, i.e., a diffeomorphic "
"transformation of a parameterless distribution. (Otherwise this "
"function has a biased gradient.)")
+ if not callable(p_log_prob):
+ raise TypeError("`p_log_prob` must be a Python `callable` function.")
return monte_carlo.expectation(
- f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
+ f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)),
samples=q.sample(num_draws, seed=seed),
- log_prob=q.log_prob,
+ log_prob=q.log_prob, # Only used if use_reparametrization=False.
use_reparametrization=use_reparametrization)
+
+
+def csiszar_vimco(f,
+ p_log_prob,
+ q,
+ num_draws,
+ num_batch_draws=1,
+ seed=None,
+ name=None):
+ """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))].
+
+ This function generalizes "Variational Inference for Monte Carlo Objectives"
+ (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences.
+
+ Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`,
+ consider using `monte_carlo_csiszar_f_divergence`.
+
+ The VIMCO loss is:
+
+ ```none
+ vimco = f(Avg{logu[i] : i=0,...,m-1})
+ where,
+ logu[i] = log( p(x, h[i]) / q(h[i] | x) )
+ h[i] iid~ q(H | x)
+ ```
+
+ Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
+ Rather, it is characterized by:
+
+ ```none
+ grad[vimco] - variance_reducing_term
+ where,
+ variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
+ (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
+ : i=0, ..., m-1 }
+ h[j;i] = { u[j] j!=i
+ { GeometricAverage{ u[k] : k!=i} j==i
+ ```
+
+ (We omitted `stop_gradient` for brevity. See implementation for more details.)
+
+ The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
+ element has been replaced by the leave-`i`-out Geometric-average.
+
+ Args:
+ f: Python `callable` representing a Csiszar-function in log-space.
+ p_log_prob: Python `callable` representing the natural-log of the
+ probability under distribution `p`. (In variational inference `p` is the
+ joint distribution.)
+ q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and
+ `log_prob(x)`. (In variational inference `q` is the approximate posterior
+ distribution.)
+ num_draws: Integer scalar number of draws used to approximate the
+ f-Divergence expectation.
+ num_batch_draws: Integer scalar number of independent batches of `num_draws`
+ samples; the objective is averaged across these batches.
+ seed: Python `int` seed for `q.sample`.
+ name: Python `str` name prefixed to Ops created by this function.
+
+ Returns:
+ vimco: The Csiszar f-Divergence generalized VIMCO objective.
+
+ Raises:
+ ValueError: if `num_draws < 2`.
+ """
+ with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]):
+ if num_draws < 2:
+ raise ValueError("Must specify num_draws > 1.")
+ stop = array_ops.stop_gradient # For readability.
+ x = stop(q.sample(sample_shape=[num_draws, num_batch_draws],
+ seed=seed))
+ logqx = q.log_prob(x)
+ logu = p_log_prob(x) - logqx
+ f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)]
+ dotprod = math_ops.reduce_sum(
+ logqx * stop(f_log_avg_u - f_log_sooavg_u),
+ axis=0) # Sum over iid samples.
+ # We now rewrite f_log_avg_u so that:
+ # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
+ # To achieve this, we use a trick that
+ # `f(x) - stop(f(x)) == zeros_like(f(x))`
+ # but its gradient is grad[f(x)].
+ # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
+ # this trick loses no precision. For more discussion regarding the relevant
+ # portions of the IEEE754 standard, see the StackOverflow question,
+ # "Is there a floating point value of x, for which x-x == 0 is false?"
+ # http://stackoverflow.com/q/2686644
+ f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dotprod).
+ return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches.
+
+
+def csiszar_vimco_helper(logu, name=None):
+ """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`.
+
+ `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e.,
+
+ ```none
+ logu[j] = log(u[j])
+ u[j] = p(x, h[j]) / q(h[j] | x)
+ h[j] iid~ q(H | x)
+ ```
+
+ Args:
+ logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`.
+ name: Python `str` name prefixed to Ops created by this function.
+
+ Returns:
+ log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the
+ average of `u`.
+ log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the
+ average of `u` except that the average swaps-out `u[i]` for the
+ leave-`i`-out Geometric-average, i.e.,
+
+ ```none
+ log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1})
+ h[j ; i] = { u[j] j!=i
+ { GeometricAverage{u[k] : k != i} j==i
+ ```
+
+ """
+ with ops.name_scope(name, "csiszar_vimco_helper", [logu]):
+ logu = ops.convert_to_tensor(logu, name="logu")
+
+ n = logu.shape.with_rank_at_least(1)[0].value
+ if n is None:
+ n = array_ops.shape(logu)[0]
+ log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype))
+ nm1 = math_ops.cast(n - 1, dtype=logu.dtype)
+ else:
+ log_n = np.log(n).astype(logu.dtype.as_numpy_dtype)
+ nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype)
+
+ # Throughout we reduce across axis=0 since this is presumed to be iid
+ # samples.
+
+ log_sum_u = math_ops.reduce_logsumexp(logu, axis=0)
+
+ # log_loosum_u[i] =
+ # = logsumexp(logu[j] : j != i)
+ # = log( exp(logsumexp(logu)) - exp(logu[i]) )
+ # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i]))
+ # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1)
+ # = logu[i] + softplus_inverse(logsumexp(logu - logu[i]))
+ log_loosum_u = logu + distribution_util.softplus_inverse(log_sum_u - logu)
+
+ # The swap-one-out-sum ("soosum") is n different sums, each of which
+ # replaces the i-th item with the i-th-left-out average, i.e.,
+ # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i]))
+ # = exp(log_loosum_u[i]) + exp(looavg_logu[i])
+ looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1
+ log_soosum_u = math_ops.reduce_logsumexp(
+ array_ops.stack([log_loosum_u, looavg_logu]),
+ axis=0)
+
+ return log_sum_u - log_n, log_soosum_u - log_n
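
The leave-one-out identity used by `csiszar_vimco_helper` above (the all-but-one logsumexp rewritten via `softplus_inverse`) can be checked numerically. A small sketch, not part of the commit, using scipy only for a reference `logsumexp`:

```python
# Sketch only: verify
#   logsumexp(logu[j] : j != i)
#     == logu[i] + softplus_inverse(logsumexp(logu) - logu[i])
# where softplus_inverse(y) = log(exp(y) - 1) = log(expm1(y)).
import numpy as np
from scipy.special import logsumexp

logu = np.random.randn(7)
lse = logsumexp(logu)

direct = np.array([logsumexp(np.delete(logu, i)) for i in range(len(logu))])
via_identity = logu + np.log(np.expm1(lse - logu))

np.testing.assert_allclose(direct, via_identity)
```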
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 0dae74d31f..985177e897 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -220,9 +220,10 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
`S_n = Avg{s_i}` and `s_i = f(x_i), x_i ~ p`.
However, if p is not reparameterized, TensorFlow's gradient will be incorrect
- since the chain-rule stops at samples of unreparameterized distributions. In
- this circumstance using the Score-Gradient trick results in an unbiased
- gradient, i.e.,
+ since the chain-rule stops at samples of non-reparameterized distributions.
+ (The non-differentiated result, `approx_expectation`, is the same regardless
+ of `use_reparametrization`.) In this circumstance using the Score-Gradient
+ trick results in an unbiased gradient, i.e.,
```none
grad[ E_p[f(X)] ]
@@ -240,6 +241,58 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
Warning: users are responsible for verifying `p` is a "reparameterized"
distribution.
+ Example Use:
+
+ ```python
+ bf = tf.contrib.bayesflow
+ ds = tf.contrib.distributions
+
+ # Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
+
+ num_draws = int(1e5)
+ p = ds.Normal(loc=0., scale=1.)
+ q = ds.Normal(loc=1., scale=2.)
+ exact_kl_normal_normal = ds.kl_divergence(p, q)
+ # ==> 0.44314718
+ approx_kl_normal_normal = bf.expectation(
+ f=lambda x: p.log_prob(x) - q.log_prob(x),
+ samples=p.sample(num_draws, seed=42),
+ log_prob=p.log_prob,
+ use_reparametrization=(p.reparameterization_type
+ == distribution.FULLY_REPARAMETERIZED))
+ # ==> 0.44632751
+ # Relative Error: <1%
+
+ # Monte-Carlo approximation of non-reparameterized distribution, e.g., Gamma.
+
+ num_draws = int(1e5)
+ p = ds.Gamma(concentration=1., rate=1.)
+ q = ds.Gamma(concentration=2., rate=3.)
+ exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+ # ==> 0.37999129
+ approx_kl_gamma_gamma = bf.expectation(
+ f=lambda x: p.log_prob(x) - q.log_prob(x),
+ samples=p.sample(num_draws, seed=42),
+ log_prob=p.log_prob,
+ use_reparametrization=(p.reparameterization_type
+ == distribution.FULLY_REPARAMETERIZED))
+ # ==> 0.37696719
+ # Relative Error: <1%
+
+ # For comparing the gradients, see `monte_carlo_test.py`.
+ ```
+
+ Note: The above example is for illustration only. To compute approximate
+ KL-divergence, the following is preferred:
+
+ ```python
+ approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
+ f=bf.kl_reverse,
+ p_log_prob=q.log_prob,
+ q=p,
+ num_draws=num_draws)
+ ```
+
Args:
f: Python callable which can return `f(samples)`.
samples: `Tensor` of samples used to form the Monte-Carlo approximation of
@@ -247,21 +300,27 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
log_prob: Python callable which can return `log_prob(samples)`. Must
correspond to the natural-logarithm of the pdf/pmf of each sample. Only
required/used if `use_reparametrization=False`.
+ Default value: `None`.
use_reparametrization: Python `bool` indicating that the approximation
- should use the fact that the gradient of samples is unbiased.
- axis: The dimensions to average. If `None` (the default), averages all
+ should use the fact that the gradient of samples is unbiased. Whether
+ `True` or `False`, this arg only affects the gradient of the resulting
+ `approx_expectation`.
+ Default value: `True`.
+ axis: The dimensions to average. If `None`, averages all
dimensions.
- keep_dims: If true, retains averaged dimensions with length 1.
- name: A `name_scope` for operations created by this function (optional).
- Default value: "expectation".
+ Default value: `0` (the left-most dimension).
+ keep_dims: If True, retains averaged dimensions using size `1`.
+ Default value: `False`.
+ name: A `name_scope` for operations created by this function.
+ Default value: `None` (which implies "expectation").
Returns:
approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation
of `E_p[f(X)]`.
Raises:
- ValueError: if `f` is not `callable`.
- ValueError: if `use_reparametrization=False` and `log_prob` is not
+ ValueError: if `f` is not a Python `callable`.
+ ValueError: if `use_reparametrization=False` and `log_prob` is not a Python
`callable`.
"""
@@ -273,19 +332,23 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
else:
if not callable(log_prob):
raise ValueError('`log_prob` must be a callable function.')
- x = array_ops.stop_gradient(samples)
+ stop = array_ops.stop_gradient # For readability.
+ x = stop(samples)
logpx = log_prob(x)
- # Numerically, exp(g(x) - stop[g(x)]) is always 1, even if exp(g(x)) is
- # unstable. But the gradient is also stable, ie,
- # d/dx exp(g(x) - stop[g(x)])
- # = exp(g(x) - stop[g(x)]) d/dx g(x)
- # = d/dx g(x) [numerically exact since IEEE754 has the property
- # that for any finite floating-point number x:
- # x - x == 0.0]
- return math_ops.reduce_mean(
- f(x) * math_ops.exp(logpx - array_ops.stop_gradient(logpx)),
- axis=axis,
- keep_dims=keep_dims)
+ fx = f(x) # Call `f` once in case it has side-effects.
+ # We now rewrite f(x) so that:
+ # `grad[f(x)] := grad[f(x)] + f(x) * grad[logqx]`.
+ # To achieve this, we use a trick that
+ # `h(x) - stop(h(x)) == zeros_like(h(x))`
+ # but its gradient is grad[h(x)].
+ # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
+ # this trick loses no precision. For more discussion regarding the
+ # relevant portions of the IEEE754 standard, see the StackOverflow
+ # question,
+ # "Is there a floating point value of x, for which x-x == 0 is false?"
+ # http://stackoverflow.com/q/2686644
+ fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx).
+ return math_ops.reduce_mean(fx, axis=axis, keep_dims=keep_dims)
def _sample_mean(values):
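
The rewritten score-trick in `expectation` above leans entirely on the `x - stop(x)` identity. A standalone sketch (not part of the commit) of the same rewrite in plain TF 1.x, where the gradient of the rewritten value recovers the score-function estimator `E[f(X) * grad(log p(X; theta))]`:

```python
# Sketch only: value of `fx_rewritten` equals `fx`, but its gradient w.r.t.
# `theta` picks up the score-function term f(x) * d/dtheta log p(x; theta).
import tensorflow as tf

ds = tf.contrib.distributions

theta = tf.constant(1.5)
p = ds.Normal(loc=theta, scale=1.)

x = tf.stop_gradient(p.sample(10000, seed=42))  # no pathwise gradient
fx = tf.square(x)                               # some per-sample f(x)
logpx = p.log_prob(x)

fx_rewritten = fx + tf.stop_gradient(fx) * (logpx - tf.stop_gradient(logpx))
score_grad = tf.gradients(tf.reduce_mean(fx_rewritten), theta)[0]

with tf.Session() as sess:
  # For X ~ N(theta, 1), d/dtheta E[X^2] = 2 * theta = 3.0; estimate is noisy.
  print(sess.run(score_grad))
```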