| author | 2017-07-21 17:51:03 -0700 |
|---|---|
| committer | 2017-07-21 17:55:10 -0700 |
| commit | 9c08e99a02c682acca0dcb9be9f79c3a9d38b615 (patch) |
| tree | 1e3c55bac1f1ffc6b7ce480ecbe795c9cd7304ba |
| parent | 4895cec77fac920a77a79e9e6f79951503483303 (diff) |
Implement VIMCO-like objective for approx Csiszar f-Divergence. Simplify monte_carlo.expectation_v2 calculation when doing score-trick.
PiperOrigin-RevId: 162806278
5 files changed, 544 insertions, 64 deletions
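Before the diff itself, here is a minimal sketch (not part of this commit) of the new `csiszar_vimco` entry point, using the argument names exercised by the tests below. The `tf.contrib.bayesflow.csiszar_divergence` / `tf.contrib.distributions` paths are assumed from the TF 1.x contrib layout of this era.

```python
# Sketch only: the new VIMCO-style objective added by this commit.
import tensorflow as tf

cd = tf.contrib.bayesflow.csiszar_divergence
ds = tf.contrib.distributions

p = ds.MultivariateNormalDiag(loc=tf.zeros([5]))        # Stand-in joint model.
q = ds.MultivariateNormalDiag(scale_diag=tf.ones([5]))  # Approximate posterior.

vimco = cd.csiszar_vimco(
    f=cd.kl_reverse,        # Csiszar-function in log-space.
    p_log_prob=p.log_prob,  # Note: now a callable, not a Distribution.
    q=q,
    num_draws=20,           # Must be >= 2.
    num_batch_draws=3,
    seed=1)

with tf.Session() as sess:
  print(sess.run(vimco))
```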
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index dda3f60065..d06b69885c 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -37,6 +37,14 @@ from tensorflow.python.platform import test
 cd = csiszar_divergence_impl
 
 
+def tridiag(d, diag_value, offdiag_value):
+  """d x d matrix with given value on diag, and one super/sub diag."""
+  diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
+  three_bands = array_ops.matrix_band_part(
+      array_ops.fill([d, d], offdiag_value), 1, 1)
+  return diag_mat + three_bands
+
+
 class AmariAlphaTest(test.TestCase):
 
   def setUp(self):
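For intuition, a NumPy equivalent (illustration only, not part of the commit) of the `tridiag` helper added in the hunk above:

```python
# NumPy rendering of `tridiag`: value on the main diagonal, another value on
# the first super- and sub-diagonals, zeros elsewhere.
import numpy as np

def np_tridiag(d, diag_value, offdiag_value):
  # Keep one band above/below the main diagonal (the numpy analogue of
  # `matrix_band_part(x, 1, 1)`), then correct the diagonal.
  band = np.triu(np.tril(np.full((d, d), offdiag_value), 1), -1)
  return np.eye(d) * (diag_value - offdiag_value) + band

print(np_tridiag(3, diag_value=1., offdiag_value=0.5))
# [[1.  0.5 0. ]
#  [0.5 1.  0.5]
#  [0.  0.5 1. ]]
```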
@@ -483,14 +491,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
 
       approx_kl = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_forward,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
 
       approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
           f=lambda logu: cd.kl_forward(logu, self_normalized=True),
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
@@ -517,14 +525,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
 
       approx_kl = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_reverse,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
 
       approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
           f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
@@ -540,33 +548,26 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                           rtol=0.02, atol=0.)
 
-  def _tridiag(self, d, diag_value, offdiag_value):
-    """d x d matrix with given value on diag, and one super/sub diag."""
-    diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value)
-    three_bands = array_ops.matrix_band_part(
-        array_ops.fill([d, d], offdiag_value), 1, 1)
-    return diag_mat + three_bands
-
   def test_kl_reverse_multidim(self):
 
     with self.test_session() as sess:
       d = 5  # Dimension
 
       p = mvn_full_lib.MultivariateNormalFullCovariance(
-          covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+          covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
 
       q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d)
 
       approx_kl = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_reverse,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
 
       approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
           f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
@@ -588,7 +589,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       d = 5  # Dimension
 
       p = mvn_full_lib.MultivariateNormalFullCovariance(
-          covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+          covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
 
       # Variance is very high when approximating Forward KL, so we make
       # scale_diag larger than in test_kl_reverse_multidim. This ensures q
@@ -597,14 +598,14 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
 
       approx_kl = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_forward,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
 
       approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
           f=lambda logu: cd.kl_forward(logu, self_normalized=True),
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=int(1e5),
           seed=1)
@@ -628,7 +629,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       seed = 1
 
       p = mvn_full_lib.MultivariateNormalFullCovariance(
-          covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+          covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5))
 
       # Variance is very high when approximating Forward KL, so we make
       # scale_diag larger than in test_kl_reverse_multidim. This ensures q
@@ -639,21 +640,21 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
 
       approx_kl = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_reverse,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=num_draws,
           seed=seed)
 
       approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
           f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=num_draws,
           seed=seed)
 
       approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
           f=cd.kl_reverse,
-          p=p,
+          p_log_prob=p.log_prob,
           q=q,
           num_draws=num_draws,
           use_reparametrization=False,
@@ -662,7 +663,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       approx_kl_self_normalized_score_trick = (
           cd.monte_carlo_csiszar_f_divergence(
               f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
-              p=p,
+              p_log_prob=p.log_prob,
               q=q,
               num_draws=num_draws,
               use_reparametrization=False,
@@ -670,7 +671,7 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
 
       exact_kl = kullback_leibler.kl_divergence(q, p)
 
-      grad = lambda fs: gradients_impl.gradients(fs, s)[0]
+      grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
 
       [
           approx_kl_grad_,
@@ -684,11 +685,11 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
           approx_kl_self_normalized_score_trick_,
           exact_kl_,
       ] = sess.run([
-          grad(approx_kl),
-          grad(approx_kl_self_normalized),
-          grad(approx_kl_score_trick),
-          grad(approx_kl_self_normalized_score_trick),
-          grad(exact_kl),
+          grad_sum(approx_kl),
+          grad_sum(approx_kl_self_normalized),
+          grad_sum(approx_kl_score_trick),
+          grad_sum(approx_kl_self_normalized_score_trick),
+          grad_sum(exact_kl),
           approx_kl,
           approx_kl_self_normalized,
           approx_kl_score_trick,
@@ -724,5 +725,154 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
           rtol=0.017, atol=0.)
 
 
-if __name__ == '__main__':
+class CsiszarVIMCOTest(test.TestCase):
+
+  def _numpy_csiszar_vimco_helper(self, logu):
+    """Numpy implementation of `csiszar_vimco_helper`."""
+    n = logu.shape[0]
+    u = np.exp(logu)
+    loogeoavg_u = []  # Leave-one-out geometric-average of exp(logu).
+    for j in range(n):
+      loogeoavg_u.append(np.exp(np.mean(
+          [logu[i, ...] for i in range(n) if i != j],
+          axis=0)))
+    loogeoavg_u = np.array(loogeoavg_u)
+
+    loosum_u = []  # Leave-one-out sum of exp(logu).
+    for j in range(n):
+      loosum_u.append(np.sum(
+          [u[i, ...] for i in range(n) if i != j],
+          axis=0))
+    loosum_u = np.array(loosum_u)
+
+    # Natural log of the average u except each is swapped-out for its
+    # leave-`i`-th-out Geometric average.
+    log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n)
+
+    log_avg_u = np.log(np.mean(u, axis=0))
+    return log_avg_u, log_sooavg_u
+
+  def test_vimco_helper(self):
+
+    with self.test_session() as sess:
+      logu = np.linspace(-20, 20, 100)
+      np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu)
+      [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu))
+      self.assertAllClose(np_log_avg_u, log_avg_u,
+                          rtol=1e-2, atol=0.)
+      self.assertAllClose(np_log_sooavg_u, log_sooavg_u,
+                          rtol=1e-2, atol=0.)
+
+  def test_vimco_helper_gradient(self):
+
+    with self.test_session():
+      logu = array_ops.constant(
+          np.linspace(-1e2, 100., 100).reshape([50, 2]))
+      log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu)
+      g = gradients_impl.gradients(log_avg_u - log_sooavg_u, logu)[0].eval()
+      self.assertAllEqual(np.ones_like(g, dtype=np.bool), np.isfinite(g))
+      self.assertAllEqual(np.ones_like(g, dtype=np.bool), g != 0.)
+
+  def test_vimco_and_gradient(self):
+
+    with self.test_session() as sess:
+      dims = 5  # Dimension
+      num_draws = int(20)
+      num_batch_draws = int(3)
+      seed = 1
+
+      f = lambda logu: cd.kl_reverse(logu, self_normalized=False)
+      np_f = lambda logu: -logu
+
+      p = mvn_full_lib.MultivariateNormalFullCovariance(
+          covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5))
+
+      # Variance is very high when approximating Forward KL, so we make
+      # scale_diag larger than in test_kl_reverse_multidim. This ensures q
+      # "covers" p and thus Var_q[p/q] is smaller.
+      s = array_ops.constant(1.)
+      q = mvn_diag_lib.MultivariateNormalDiag(
+          scale_diag=array_ops.tile([s], [dims]))
+
+      vimco = cd.csiszar_vimco(
+          f=f,
+          p_log_prob=p.log_prob,
+          q=q,
+          num_draws=num_draws,
+          num_batch_draws=num_batch_draws,
+          seed=seed)
+
+      x = q.sample(sample_shape=[num_draws, num_batch_draws],
+                   seed=seed)
+      x = array_ops.stop_gradient(x)
+      logu = p.log_prob(x) - q.log_prob(x)
+      f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0])
+
+      grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0]
+
+      def jacobian(x):
+        # Warning: this function is slow and may not even finish if
+        # prod(shape) is larger than, say, 100.
+        shape = x.shape.as_list()
+        assert all(s is not None for s in shape)
+        x = array_ops.reshape(x, shape=[-1])
+        r = [grad_sum(x[i]) for i in range(np.prod(shape))]
+        return array_ops.reshape(array_ops.stack(r), shape=shape)
+
+      [
+          logu_,
+          jacobian_logqx_,
+          vimco_,
+          grad_vimco_,
+          f_log_sum_u_,
+          grad_mean_f_log_sum_u_,
+      ] = sess.run([
+          logu,
+          jacobian(q.log_prob(x)),
+          vimco,
+          grad_sum(vimco),
+          f_log_sum_u,
+          grad_sum(f_log_sum_u) / num_batch_draws,
+      ])
+
+      np_log_avg_u, np_log_sooavg_u = self._numpy_csiszar_vimco_helper(logu_)
+
+      # Test VIMCO loss is correct.
+      self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_,
+                          rtol=1e-5, atol=0.)
+
+      # Test gradient of VIMCO loss is correct.
+      #
+      # To make this computation we'll inject two gradients from TF:
+      # - grad[mean(f(log(sum(p(x)/q(x)))))]
+      # - jacobian[log(q(x))].
+      #
+      # We now justify why using these (and only these) TF values for
+      # ground-truth does not undermine the completeness of this test.
+      #
+      # Regarding `grad_mean_f_log_sum_u_`, note that we validate the
+      # correctness of the zero-th order derivative (for each batch member).
+      # Since `cd.csiszar_vimco_helper` itself does not manipulate any
+      # gradient information, we can safely rely on TF.
+      self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.)
+      #
+      # Regarding `jacobian_logqx_`, note that testing the gradient of
+      # `q.log_prob` is outside the scope of this unit-test thus we may safely
+      # use TF to find it.
+
+      # The `mean` is across batches and the `sum` is across iid samples.
+      np_grad_vimco = (
+          grad_mean_f_log_sum_u_
+          + np.mean(
+              np.sum(
+                  jacobian_logqx_ * (np_f(np_log_avg_u)
+                                     - np_f(np_log_sooavg_u)),
+                  axis=0),
+              axis=0))
+
+      self.assertAllClose(np_grad_vimco, grad_vimco_,
+                          rtol=1e-5, atol=0.)
+
+
+if __name__ == "__main__":
   test.main()
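Stated as a formula (a sketch, not part of the diff), the identity that `test_vimco_and_gradient` above checks numerically is, in the notation of the `csiszar_vimco` docstring further down in this commit:

```none
grad[vimco] = grad[ Avg_batch{ f(log Avg{u[i] : i}) } ]
              + Avg_batch{ Sum_i grad[log q(h[i] | x)] *
                           (f(log Avg{u[j] : j}) - f(log Avg{h[j;i] : j})) }
```

Here `grad_mean_f_log_sum_u_` supplies the first term and `jacobian_logqx_` supplies the `grad[log q(h[i] | x)]` factors in the second.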
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 38426ebf1f..d9e23646d8 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 
@@ -210,6 +213,93 @@ class ExpectationTest(test.TestCase):
           efx_score_grad_[2:-2],
           rtol=0.05, atol=0.)
 
+  def test_docstring_example_normal(self):
+    with self.test_session() as sess:
+      num_draws = int(1e5)
+      mu_p = constant_op.constant(0.)
+      mu_q = constant_op.constant(1.)
+      p = normal_lib.Normal(loc=mu_p, scale=1.)
+      q = normal_lib.Normal(loc=mu_q, scale=2.)
+      exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
+      approx_kl_normal_normal = monte_carlo_lib.expectation(
+          f=lambda x: p.log_prob(x) - q.log_prob(x),
+          samples=p.sample(num_draws, seed=42),
+          log_prob=p.log_prob,
+          use_reparametrization=(p.reparameterization_type
+                                 == distribution_lib.FULLY_REPARAMETERIZED))
+      [exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
+          exact_kl_normal_normal, approx_kl_normal_normal])
+      self.assertEqual(
+          True,
+          p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
+      self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
+                          rtol=0.01, atol=0.)
+
+      # Compare gradients. (Not present in `docstring`.)
+      gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
+      gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
+      [
+          gradp_exact_kl_normal_normal_,
+          gradq_exact_kl_normal_normal_,
+          gradp_approx_kl_normal_normal_,
+          gradq_approx_kl_normal_normal_,
+      ] = sess.run([
+          gradp(exact_kl_normal_normal),
+          gradq(exact_kl_normal_normal),
+          gradp(approx_kl_normal_normal),
+          gradq(approx_kl_normal_normal),
+      ])
+      self.assertAllClose(gradp_exact_kl_normal_normal_,
+                          gradp_approx_kl_normal_normal_,
+                          rtol=0.01, atol=0.)
+      self.assertAllClose(gradq_exact_kl_normal_normal_,
+                          gradq_approx_kl_normal_normal_,
+                          rtol=0.01, atol=0.)
+
+  def test_docstring_example_gamma(self):
+    with self.test_session() as sess:
+      num_draws = int(1e5)
+      concentration_p = constant_op.constant(1.)
+      concentration_q = constant_op.constant(2.)
+      p = gamma_lib.Gamma(concentration=concentration_p, rate=1.)
+      q = gamma_lib.Gamma(concentration=concentration_q, rate=3.)
+      approx_kl_gamma_gamma = monte_carlo_lib.expectation(
+          f=lambda x: p.log_prob(x) - q.log_prob(x),
+          samples=p.sample(num_draws, seed=42),
+          log_prob=p.log_prob,
+          use_reparametrization=(p.reparameterization_type
+                                 == distribution_lib.FULLY_REPARAMETERIZED))
+      exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q)
+      [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([
+          exact_kl_gamma_gamma, approx_kl_gamma_gamma])
+      self.assertEqual(
+          False,
+          p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
+      self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_,
+                          rtol=0.01, atol=0.)
+
+      # Compare gradients. (Not present in `docstring`.)
+      gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0]
+      gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0]
+      [
+          gradp_exact_kl_gamma_gamma_,
+          gradq_exact_kl_gamma_gamma_,
+          gradp_approx_kl_gamma_gamma_,
+          gradq_approx_kl_gamma_gamma_,
+      ] = sess.run([
+          gradp(exact_kl_gamma_gamma),
+          gradq(exact_kl_gamma_gamma),
+          gradp(approx_kl_gamma_gamma),
+          gradq(approx_kl_gamma_gamma),
+      ])
+      # Notice that variance (i.e., `rtol`) is higher when using score-trick.
+      self.assertAllClose(gradp_exact_kl_gamma_gamma_,
+                          gradp_approx_kl_gamma_gamma_,
+                          rtol=0.05, atol=0.)
+      self.assertAllClose(gradq_exact_kl_gamma_gamma_,
+                          gradq_approx_kl_gamma_gamma_,
+                          rtol=0.03, atol=0.)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
index b1bcc86022..9f7a95f138 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py
@@ -31,6 +31,7 @@ _allowed_symbols = [
     'amari_alpha',
     'arithmetic_geometric',
     'chi_square',
+    'csiszar_vimco',
     'dual_csiszar_function',
     'jeffreys',
     'jensen_shannon',
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 247d434bb7..54900ab893 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -17,6 +17,7 @@
 @@amari_alpha
 @@arithmetic_geometric
 @@chi_square
+@@csiszar_vimco
 @@dual_csiszar_function
 @@jeffreys
 @@jensen_shannon
@@ -46,6 +47,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
 
 
 def amari_alpha(logu, alpha=1., self_normalized=False, name=None):
@@ -696,7 +698,7 @@ def dual_csiszar_function(logu, csiszar_function, name=None):
 
   Args:
     logu: `float`-like `Tensor` representing `log(u)` from above.
-    csiszar_function: Python callable representing a Csiszar-function over
+    csiszar_function: Python `callable` representing a Csiszar-function over
       log-domain.
     name: Python `str` name prefixed to Ops created by this function.
 
@@ -765,7 +767,7 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
 
   Args:
     logu: `float`-like `Tensor` representing `log(u)` from above.
-    csiszar_function: Python callable representing a Csiszar-function over
+    csiszar_function: Python `callable` representing a Csiszar-function over
       log-domain.
     name: Python `str` name prefixed to Ops created by this function.
@@ -781,7 +783,13 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
 
 
 def monte_carlo_csiszar_f_divergence(
-    f, p, q, num_draws, use_reparametrization=True, seed=None, name=None):
+    f,
+    p_log_prob,
+    q,
+    num_draws,
+    use_reparametrization=None,
+    seed=None,
+    name=None):
   """Monte-Carlo approximation of the Csiszar f-Divergence.
 
   A Csiszar-function is a member of,
@@ -843,15 +851,22 @@ def monte_carlo_csiszar_f_divergence(
     "Evidence Divergence Bound Optimization" (EDBO).
 
   Args:
-    f: Python callable representing a Csiszar-function in log-space.
-    p: `tf.Distribution`-like instance; must implement `log_prob(x)`.
+    f: Python `callable` representing a Csiszar-function in log-space, i.e.,
+      takes `p_log_prob(q_samples) - q.log_prob(q_samples)`.
+    p_log_prob: Python `callable` taking (a batch of) samples from `q` and
+      returning the natural-log of the probability under distribution `p`.
+      (In variational inference `p` is the joint distribution.)
     q: `tf.Distribution`-like instance; must implement:
-      `reparameterization_type`, `sample(n)`, and `log_prob(x)`.
+      `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`.
+      (In variational inference `q` is the approximate posterior
+      distribution.)
     num_draws: Integer scalar number of draws used to approximate the
       f-Divergence expectation.
-    use_reparametrization: Python `bool`. When `True` uses the standard
-      Monte-Carlo average. When `False` uses the score-gradient trick. (See
-      above for details.)
+    use_reparametrization: Python `bool`. When `None` (the default),
+      automatically set to:
+      `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`.
+      When `True` uses the standard Monte-Carlo average. When `False` uses
+      the score-gradient trick. (See above for details.) When `False`,
+      consider using `csiszar_vimco`.
     seed: Python `int` seed for `q.sample`.
     name: Python `str` name prefixed to Ops created by this function.
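In other words, callers can now omit `use_reparametrization` and let the estimator be chosen from `q`. A minimal sketch (not part of the diff; reuses the `cd`, `p`, `q` names from the tests earlier in this commit):

```python
# Sketch: with `use_reparametrization=None` (the new default), the estimator
# is picked from `q.reparameterization_type` automatically.
approx_kl = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,
    p_log_prob=p.log_prob,  # Was `p=p` before this change.
    q=q,                    # Fully reparameterized => standard MC average;
    num_draws=int(1e5))     # otherwise the score-gradient trick is used.
```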
@@ -866,18 +881,179 @@ def monte_carlo_csiszar_f_divergence(
       samples of another distribution which does not depend on the
       parameterization of `q`. This property ensures the gradient (with
       respect to parameters) is valid.
+    TypeError: if `p_log_prob` is not a Python `callable`.
   """
   with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
-    if (use_reparametrization and
-        q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
+    if use_reparametrization is None:
+      use_reparametrization = (q.reparameterization_type
+                               == distribution.FULLY_REPARAMETERIZED)
+    elif (use_reparametrization and
+          q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
       # TODO(jvdillon): Consider only raising an exception if the gradient is
       # requested.
       raise ValueError(
           "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
           "transformation of a parameterless distribution. (Otherwise this "
           "function has a biased gradient.)")
+    if not callable(p_log_prob):
+      raise TypeError("`p_log_prob` must be a Python `callable` function.")
     return monte_carlo.expectation(
-        f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
+        f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)),
         samples=q.sample(num_draws, seed=seed),
-        log_prob=q.log_prob,
+        log_prob=q.log_prob,  # Only used if use_reparametrization=False.
         use_reparametrization=use_reparametrization)
+
+
+def csiszar_vimco(f,
+                  p_log_prob,
+                  q,
+                  num_draws,
+                  num_batch_draws=1,
+                  seed=None,
+                  name=None):
+  """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))].
+
+  This function generalizes "Variational Inference for Monte Carlo Objectives"
+  (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences.
+
+  Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`,
+  consider using `monte_carlo_csiszar_f_divergence`.
+
+  The VIMCO loss is:
+
+  ```none
+  vimco = f(Avg{logu[i] : i=0,...,m-1})
+  where,
+    logu[i] = log( p(x, h[i]) / q(h[i] | x) )
+    h[i] iid~ q(H | x)
+  ```
+
+  Interestingly, the VIMCO gradient is not the naive gradient of `vimco`.
+  Rather, it is characterized by:
+
+  ```none
+  grad[vimco] - variance_reducing_term
+  where,
+    variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
+                                  (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
+                                  : i=0, ..., m-1 }
+    h[j;i] = { u[j]                             j!=i
+             { GeometricAverage{ u[k] : k!=i }  j==i
+  ```
+
+  (We omitted `stop_gradient` for brevity. See implementation for more
+  details.)
+
+  The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th
+  element has been replaced by the leave-`i`-out Geometric-average.
+
+  Args:
+    f: Python `callable` representing a Csiszar-function in log-space.
+    p_log_prob: Python `callable` representing the natural-log of the
+      probability under distribution `p`. (In variational inference `p` is
+      the joint distribution.)
+    q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`,
+      and `log_prob(x)`. (In variational inference `q` is the approximate
+      posterior distribution.)
+    num_draws: Integer scalar number of draws used to approximate the
+      f-Divergence expectation.
+    num_batch_draws: Integer scalar number of draws used to approximate the
+      f-Divergence expectation.
+    seed: Python `int` seed for `q.sample`.
+    name: Python `str` name prefixed to Ops created by this function.
+
+  Returns:
+    vimco: The Csiszar f-Divergence generalized VIMCO objective.
+
+  Raises:
+    ValueError: if `num_draws < 2`.
+  """
+  with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]):
+    if num_draws < 2:
+      raise ValueError("Must specify num_draws > 1.")
+    stop = array_ops.stop_gradient  # For readability.
+    x = stop(q.sample(sample_shape=[num_draws, num_batch_draws],
+                      seed=seed))
+    logqx = q.log_prob(x)
+    logu = p_log_prob(x) - logqx
+    f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)]
+    dotprod = math_ops.reduce_sum(
+        logqx * stop(f_log_avg_u - f_log_sooavg_u),
+        axis=0)  # Sum over iid samples.
+    # We now rewrite f_log_avg_u so that:
+    #   `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`.
+    # To achieve this, we use a trick that
+    #   `f(x) - stop(f(x)) == zeros_like(f(x))`
+    # but its gradient is grad[f(x)].
+    # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
+    # this trick loses no precision. For more discussion regarding the
+    # relevant portions of the IEEE754 standard, see the StackOverflow
+    # question,
+    # "Is there a floating point value of x, for which x-x == 0 is false?"
+    # http://stackoverflow.com/q/2686644
+    f_log_avg_u += dotprod - stop(dotprod)  # Add zeros_like(dot_prod).
+    return math_ops.reduce_mean(f_log_avg_u, axis=0)  # Avg over batches.
+
+
+def csiszar_vimco_helper(logu, name=None):
+  """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`.
+
+  `axis = 0` of `logu` is presumed to correspond to iid samples from `q`,
+  i.e.,
+
+  ```none
+  logu[j] = log(u[j])
+  u[j] = p(x, h[j]) / q(h[j] | x)
+  h[j] iid~ q(H | x)
+  ```
+
+  Args:
+    logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`.
+    name: Python `str` name prefixed to Ops created by this function.
+
+  Returns:
+    log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the
+      average of `u`.
+    log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of
+      the average of `u` except that the average swaps-out `u[i]` for the
+      leave-`i`-out Geometric-average, i.e.,
+
+      ```none
+      log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1})
+      h[j ; i] = { u[j]                              j!=i
+                 { GeometricAverage{u[k] : k != i}   j==i
+      ```
+
+  """
+  with ops.name_scope(name, "csiszar_vimco_helper", [logu]):
+    logu = ops.convert_to_tensor(logu, name="logu")
+
+    n = logu.shape.with_rank_at_least(1)[0].value
+    if n is None:
+      n = array_ops.shape(logu)[0]
+      log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype))
+      nm1 = math_ops.cast(n - 1, dtype=logu.dtype)
+    else:
+      log_n = np.log(n).astype(logu.dtype.as_numpy_dtype)
+      nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype)
+
+    # Throughout we reduce across axis=0 since this is presumed to be iid
+    # samples.
+
+    log_sum_u = math_ops.reduce_logsumexp(logu, axis=0)
+
+    # log_loosum_u[i] =
+    # = logsumexp(logu[j] : j != i)
+    # = log( exp(logsumexp(logu)) - exp(logu[i]) )
+    # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i]) )
+    # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1)
+    # = logu[i] + softplus_inverse(logsumexp(logu - logu[i]))
+    log_loosum_u = logu + distribution_util.softplus_inverse(log_sum_u - logu)
+
+    # The swap-one-out-sum ("soosum") is n different sums, each of which
+    # replaces the i-th item with the i-th-left-out average, i.e.,
+    # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i]))
+    #              = exp(log_loosum_u[i])      + exp(looavg_logu[i])
+    looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1
+    log_soosum_u = math_ops.reduce_logsumexp(
+        array_ops.stack([log_loosum_u, looavg_logu]),
+        axis=0)
+
+    return log_sum_u - log_n, log_soosum_u - log_n
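The leave-one-out identity derived in the comment above can be sanity-checked outside TF. A small NumPy sketch (not part of the commit; assumes `scipy.special.logsumexp` is available, `scipy.misc.logsumexp` in older SciPy):

```python
# Check: logsumexp(logu[j] : j != i)
#        == logu[i] + softplus_inverse(logsumexp(logu) - logu[i]),
# where softplus_inverse(y) = log(exp(y) - 1).
import numpy as np
from scipy.special import logsumexp

logu = np.random.randn(7)
log_sum_u = logsumexp(logu)
softplus_inverse = lambda y: np.log(np.expm1(y))
log_loosum_u = logu + softplus_inverse(log_sum_u - logu)

# Direct computation: logsumexp over all-but-the-i-th entry.
direct = np.array([logsumexp(np.delete(logu, i)) for i in range(len(logu))])
np.testing.assert_allclose(log_loosum_u, direct, rtol=1e-10)
```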
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 0dae74d31f..985177e897 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -220,9 +220,10 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
   `S_n = Avg{s_i}` and `s_i = f(x_i), x_i ~ p`.
 
   However, if p is not reparameterized, TensorFlow's gradient will be incorrect
-  since the chain-rule stops at samples of unreparameterized distributions. In
-  this circumstance using the Score-Gradient trick results in an unbiased
-  gradient, i.e.,
+  since the chain-rule stops at samples of non-reparameterized distributions.
+  (The non-differentiated result, `approx_expectation`, is the same regardless
+  of `use_reparametrization`.) In this circumstance using the Score-Gradient
+  trick results in an unbiased gradient, i.e.,
 
   ```none
   grad[ E_p[f(X)] ]
@@ -240,6 +241,58 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
   Warning: users are responsible for verifying `p` is a "reparameterized"
   distribution.
 
+  Example Use:
+
+  ```python
+  bf = tf.contrib.bayesflow
+  ds = tf.contrib.distributions
+
+  # Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
+
+  num_draws = int(1e5)
+  p = ds.Normal(loc=0., scale=1.)
+  q = ds.Normal(loc=1., scale=2.)
+  exact_kl_normal_normal = ds.kl_divergence(p, q)
+  # ==> 0.44314718
+  approx_kl_normal_normal = bf.expectation(
+      f=lambda x: p.log_prob(x) - q.log_prob(x),
+      samples=p.sample(num_draws, seed=42),
+      log_prob=p.log_prob,
+      use_reparametrization=(p.reparameterization_type
+                             == distribution.FULLY_REPARAMETERIZED))
+  # ==> 0.44632751
+  # Relative Error: <1%
+
+  # Monte-Carlo approximation of non-reparameterized distribution,
+  # e.g., Gamma.
+
+  num_draws = int(1e5)
+  p = ds.Gamma(concentration=1., rate=1.)
+  q = ds.Gamma(concentration=2., rate=3.)
+  exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+  # ==> 0.37999129
+  approx_kl_gamma_gamma = bf.expectation(
+      f=lambda x: p.log_prob(x) - q.log_prob(x),
+      samples=p.sample(num_draws, seed=42),
+      log_prob=p.log_prob,
+      use_reparametrization=(p.reparameterization_type
+                             == distribution.FULLY_REPARAMETERIZED))
+  # ==> 0.37696719
+  # Relative Error: <1%
+
+  # For comparing the gradients, see `monte_carlo_test.py`.
+  ```
+
+  Note: The above example is for illustration only. To compute approximate
+  KL-divergence, the following is preferred:
+
+  ```python
+  approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
+      f=bf.kl_reverse,
+      p_log_prob=q.log_prob,
+      q=p,
+      num_draws=num_draws)
+  ```
+
   Args:
     f: Python callable which can return `f(samples)`.
     samples: `Tensor` of samples used to form the Monte-Carlo approximation of
@@ -247,21 +300,27 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
     log_prob: Python callable which can return `log_prob(samples)`. Must
       correspond to the natural-logarithm of the pdf/pmf of each sample. Only
       required/used if `use_reparametrization=False`.
+      Default value: `None`.
     use_reparametrization: Python `bool` indicating that the approximation
-      should use the fact that the gradient of samples is unbiased.
-    axis: The dimensions to average. If `None` (the default), averages all
+      should use the fact that the gradient of samples is unbiased. Whether
+      `True` or `False`, this arg only affects the gradient of the resulting
+      `approx_expectation`.
+      Default value: `True`.
+    axis: The dimensions to average. If `None`, averages all
       dimensions.
-    keep_dims: If true, retains averaged dimensions with length 1.
-    name: A `name_scope` for operations created by this function (optional).
-      Default value: "expectation".
+      Default value: `0` (the left-most dimension).
+    keep_dims: If True, retains averaged dimensions using size `1`.
+      Default value: `False`.
+    name: A `name_scope` for operations created by this function.
+      Default value: `None` (which implies "expectation").
 
   Returns:
     approx_expectation: `Tensor` corresponding to the Monte-Carlo
       approximation of `E_p[f(X)]`.
 
   Raises:
-    ValueError: if `f` is not `callable`.
-    ValueError: if `use_reparametrization=False` and `log_prob` is not
+    ValueError: if `f` is not a Python `callable`.
+    ValueError: if `use_reparametrization=False` and `log_prob` is not a
       Python `callable`.
   """
@@ -273,19 +332,23 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
     else:
       if not callable(log_prob):
         raise ValueError('`log_prob` must be a callable function.')
-      x = array_ops.stop_gradient(samples)
+      stop = array_ops.stop_gradient  # For readability.
+      x = stop(samples)
       logpx = log_prob(x)
-      # Numerically, exp(g(x) - stop[g(x)]) is always 1, even if exp(g(x)) is
-      # unstable. But the gradient is also stable, ie,
-      #   d/dx exp(g(x) - stop[g(x)])
-      #   = exp(g(x) - stop[g(x)]) d/dx g(x)
-      #   = d/dx g(x)   [numerically exact since IEEE754 has the property
-      #                  that for any finite floating-point number x:
-      #                  x - x == 0.0]
-      return math_ops.reduce_mean(
-          f(x) * math_ops.exp(logpx - array_ops.stop_gradient(logpx)),
-          axis=axis,
-          keep_dims=keep_dims)
+      fx = f(x)  # Call `f` once in case it has side-effects.
+      # We now rewrite f(x) so that:
+      #   `grad[f(x)] := grad[f(x)] + f(x) * grad[logpx]`.
+      # To achieve this, we use a trick that
+      #   `h(x) - stop(h(x)) == zeros_like(h(x))`
+      # but its gradient is grad[h(x)].
+      # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`,
+      # hence this trick loses no precision. For more discussion regarding
+      # the relevant portions of the IEEE754 standard, see the StackOverflow
+      # question,
+      # "Is there a floating point value of x, for which x-x == 0 is false?"
+      # http://stackoverflow.com/q/2686644
+      fx += stop(fx) * (logpx - stop(logpx))  # Add zeros_like(logpx).
+      return math_ops.reduce_mean(fx, axis=axis, keep_dims=keep_dims)
 
 
 def _sample_mean(values):
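To see why the rewritten surrogate in `expectation` leaves the value untouched while injecting the score-function gradient, here is a standalone toy check (not from the commit; TF 1.x API assumed):

```python
# Toy check of the surrogate `fx + stop(fx) * (logpx - stop(logpx))`:
# the value equals fx exactly (IEEE754: logpx - logpx == 0.), while the
# gradient gains the score term fx * grad[logpx].
import tensorflow as tf

theta = tf.constant(0.5)
x = tf.stop_gradient(2. * theta)    # Pretend sample; gradient path blocked.
logpx = -0.5 * (x - theta) ** 2     # Toy log-density depending on theta.
fx = x ** 2                         # Integrand; no direct theta path here.
surrogate = fx + tf.stop_gradient(fx) * (logpx - tf.stop_gradient(logpx))

with tf.Session() as sess:
  v, s, g = sess.run([fx, surrogate, tf.gradients(surrogate, theta)[0]])
  # v == s exactly, and g == fx * d(logpx)/d(theta) = 1.0 * 0.5 here.
```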