From e74c7ad4bb246db8e0b011d995b558b66a6fc414 Mon Sep 17 00:00:00 2001
From: "Joshua V. Dillon"
Date: Tue, 20 Jun 2017 15:29:29 -0700
Subject: Add Score-Gradient trick to `monte_carlo_csiszar_f_divergence`.

PiperOrigin-RevId: 159623845
---
 .../python/kernel_tests/csiszar_divergence_test.py | 83 +++++++++++++++++++++
 .../python/ops/csiszar_divergence_impl.py          | 85 +++++++++++++++++-----
 2 files changed, 151 insertions(+), 17 deletions(-)

diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index c120acaefc..fabf7a9b77 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
 from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
@@ -573,6 +574,88 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
                           rtol=0.05, atol=0.)
 
+  def test_score_trick(self):
+
+    with self.test_session() as sess:
+      d = 5  # Dimension
+      num_draws = int(1e5)
+      seed = 1
+
+      p = mvn_full_lib.MultivariateNormalFullCovariance(
+          covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+
+      # Variance is very high when approximating Forward KL, so we make
+      # scale_diag larger than in test_kl_reverse_multidim. This ensures q
+      # "covers" p and thus Var_q[p/q] is smaller.
+      s = array_ops.constant(1.)
+      q = mvn_diag_lib.MultivariateNormalDiag(
+          scale_diag=array_ops.tile([s], [d]))
+
+      approx_kl = cd.monte_carlo_csiszar_f_divergence(
+          f=cd.kl_reverse,
+          p=p,
+          q=q,
+          num_draws=num_draws,
+          seed=seed)
+
+      approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
+          f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
+          p=p,
+          q=q,
+          num_draws=num_draws,
+          seed=seed)
+
+      approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
+          f=cd.kl_reverse,
+          p=p,
+          q=q,
+          num_draws=num_draws,
+          use_reparametrization=False,
+          seed=seed)
+
+      approx_kl_self_normalized_score_trick = (
+          cd.monte_carlo_csiszar_f_divergence(
+              f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
+              p=p,
+              q=q,
+              num_draws=num_draws,
+              use_reparametrization=False,
+              seed=seed))
+
+      exact_kl = kullback_leibler.kl_divergence(q, p)
+
+      grad = lambda fs: gradients_impl.gradients(fs, s)[0]
+
+      [
+          approx_kl_,
+          approx_kl_self_normalized_,
+          approx_kl_score_trick_,
+          approx_kl_self_normalized_score_trick_,
+          exact_kl_,
+      ] = sess.run([
+          grad(approx_kl),
+          grad(approx_kl_self_normalized),
+          grad(approx_kl_score_trick),
+          grad(approx_kl_self_normalized_score_trick),
+          grad(exact_kl),
+      ])
+
+      self.assertAllClose(
+          approx_kl_, exact_kl_,
+          rtol=0.06, atol=0.)
+
+      self.assertAllClose(
+          approx_kl_self_normalized_, exact_kl_,
+          rtol=0.05, atol=0.)
+
+      self.assertAllClose(
+          approx_kl_score_trick_, exact_kl_,
+          rtol=0.06, atol=0.)
+
+      self.assertAllClose(
+          approx_kl_self_normalized_score_trick_, exact_kl_,
+          rtol=0.05, atol=0.)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 262da41bda..7b51d8d932 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -41,6 +41,7 @@ import numpy as np
 
 from tensorflow.contrib import framework as contrib_framework
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops.distributions import distribution
@@ -734,7 +735,8 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
           + dual_csiszar_function(logu, csiszar_function))
 
 
-def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
+def monte_carlo_csiszar_f_divergence(
+    f, p, q, num_draws, use_reparametrization=True, seed=None, name=None):
   """Monte-Carlo approximation of the Csiszar f-Divergence.
 
   A Csiszar-function is a member of,
@@ -751,6 +753,38 @@ def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
       where x_j ~iid q(x)
   ```
 
+  Tricks: Reparameterization and Score-Gradient
+
+  When q is "reparameterized", i.e., a diffeomorphic transformation of a
+  parameterless distribution (e.g.,
+  `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
+  expectation, i.e.,
+  `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }`, where
+  `s_i = f(x_i)` and `x_i ~ q`.
+
+  However, if q is not reparameterized, TensorFlow's gradient will be incorrect
+  since the chain-rule stops at samples of unreparameterized distributions. In
+  this circumstance using the Score-Gradient trick results in an unbiased
+  gradient, i.e.,
+
+  ```none
+  nabla E_q[f(X)]
+  = nabla int dx q(x) f(x)
+  = int dx nabla [ q(x) f(x) ]
+  = int dx q'(x) f(x) + q(x) f'(x)
+  = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
+  = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ]
+  = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ]
+  ~= Avg{ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) : y_i = stopgrad[x_i], x_i ~ q}
+  ```
+
+  When `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`, it is
+  usually preferable to set `use_reparametrization = True`.
+
+  Warning: using `use_reparametrization = False` will mean that the result is
+  *not* the Csiszar f-Divergence. However, its expected gradient *is* the
+  gradient of the Csiszar f-Divergence.
+
   Example Application:
 
   The Csiszar f-Divergence is a useful framework for variational inference.
@@ -775,29 +809,46 @@ def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
       `reparameterization_type`, `sample(n)`, and `log_prob(x)`.
     num_draws: Integer scalar number of draws used to approximate the
       f-Divergence expectation.
+    use_reparametrization: Python `bool`. When `True`, uses the standard
+      Monte-Carlo average. When `False`, uses the score-gradient trick. (See
+      above for details.)
     seed: Python `int` seed for `q.sample`.
     name: Python `str` name prefixed to Ops created by this function.
 
   Returns:
     monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo
-      approximation of the Csiszar f-Divergence.
+      approximation of the Csiszar f-Divergence. Warning: using
+      `use_reparametrization = False` will mean that the result is *not* the
+      Csiszar f-Divergence. However, its expected gradient *is* the actual
+      gradient of the Csiszar f-Divergence.
 
   Raises:
-    ValueError: if `q` is not a reparameterized distribution. A distribution `q`
-      is said to be "reparameterized" when its samples are generated by
-      transforming the samples of another distribution which does not depend on
-      the parameterization of `q`. This property ensures the gradient (with
-      respect to parameters) is valid.
+    ValueError: if `q` is not a reparameterized distribution and
+      `use_reparametrization = True`. A distribution `q` is said to be
+      "reparameterized" when its samples are generated by transforming the
+      samples of another distribution which does not depend on the
+      parameterization of `q`. This property ensures the gradient (with respect
+      to parameters) is valid.
   """
   with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
-    # TODO(jvdillon): Consider only raising an exception if the gradient is
-    # requested.
-    if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
-      raise ValueError(
-          "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
-          "transformation of a parameterless distribution. (Otherwise this "
-          "function has a biased gradient.)")
     x = q.sample(num_draws, seed=seed)
-    return math_ops.reduce_mean(
-        f(p.log_prob(x) - q.log_prob(x)),
-        axis=0)
+    if use_reparametrization:
+      # TODO(jvdillon): Consider only raising an exception if the gradient is
+      # requested.
+      if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
+        raise ValueError(
+            "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
+            "transformation of a parameterless distribution. (Otherwise this "
+            "function has a biased gradient.)")
+      return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0)
+    else:
+      x = array_ops.stop_gradient(x)
+      logqx = q.log_prob(x)
+      fx = f(p.log_prob(x) - logqx)
+      # Alternatively we could have returned:
+      #   reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)
+      # This is nice because it means the result is exactly the Csiszar
+      # f-Divergence yet the gradient is unbiased. However, it is numerically
+      # unstable since it exponentiates logqx, i.e., leaves the log-domain.
+      return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx,
                                  axis=0)
-- 
cgit v1.2.3
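As a companion to the docstring above, here is a minimal NumPy sketch (an illustration only; it is not part of the patch and uses no TensorFlow APIs, and all variable names are made up for the example) comparing the two gradient estimators that `use_reparametrization` switches between. The toy problem is q_s = Normal(0, s) with f(x) = x**2, for which E_q[f(X)] = s**2 and d/ds E_q[f(X)] = 2*s, so both estimators can be checked against the closed form.

```python
import numpy as np

np.random.seed(0)
s = 1.5                  # scale parameter we differentiate with respect to
num_draws = int(1e6)

eps = np.random.randn(num_draws)  # parameterless base samples, eps ~ Normal(0, 1)
x = s * eps                       # reparameterized samples, x ~ Normal(0, s)
f = lambda z: z ** 2

# Reparameterization (pathwise) estimator: differentiate through x = s * eps,
# i.e., average d/ds f(s * eps) = 2 * s * eps**2.
reparam_grad = np.mean(2.0 * s * eps ** 2)

# Score-gradient trick: hold x fixed (the analogue of stop_gradient(x)) and
# weight f(x) by the score d/ds log q_s(x) = (x**2 - s**2) / s**3. For this
# toy f, which has no remaining dependence on s once x is held fixed, this is
# exactly what differentiating the `logqx * stop_gradient(fx) + fx` surrogate
# from the patch produces.
score = (x ** 2 - s ** 2) / s ** 3
score_grad = np.mean(f(x) * score)

exact_grad = 2.0 * s
print("exact gradient      :", exact_grad)    # 3.0
print("reparameterization  :", reparam_grad)  # ~3.0, relatively low variance
print("score-gradient trick:", score_grad)    # ~3.0, noticeably higher variance
```

At the same `num_draws`, the score-gradient estimate is typically much noisier than the pathwise estimate, which is why the docstring recommends `use_reparametrization = True` whenever `q` is fully reparameterized.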