author    Joshua V. Dillon <jvdillon@google.com>  2017-06-20 15:29:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-06-20 15:32:41 -0700
commit    e74c7ad4bb246db8e0b011d995b558b66a6fc414 (patch)
tree      d1cf290ccab36b4aa8447507a56b8330529bd739
parent    af5ad5f7d7930af90bf9b04fb475d6b3ad7604ea (diff)
Add Score-Gradient trick to `monte_carlo_csiszar_f_divergence`.
PiperOrigin-RevId: 159623845
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py  83
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py           85
2 files changed, 151 insertions(+), 17 deletions(-)
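For orientation, the new `use_reparametrization` argument selects between the standard reparameterized Monte-Carlo gradient and the score-gradient surrogate added here. Below is a minimal, hypothetical usage sketch (not part of this commit); it assumes only the module paths touched in the diff and TF 1.x graph-mode sessions:

```python
# Hypothetical usage sketch (not part of this commit); imports follow the file
# paths modified below.
import tensorflow as tf

from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl as cd
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib

d = 3
s = tf.constant(1.5)  # variational scale parameter to differentiate w.r.t.
p = mvn_diag_lib.MultivariateNormalDiag(scale_diag=tf.ones([d]))
q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=tf.tile([s], [d]))

# New in this commit: `use_reparametrization=False` builds a surrogate whose
# gradient is the score-gradient estimator described in the docstring below.
surrogate = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,
    p=p,
    q=q,
    num_draws=int(1e4),
    use_reparametrization=False,
    seed=42)

# Monte-Carlo estimate of d/ds KL[q || p].
grad_wrt_s = tf.gradients(surrogate, s)[0]

with tf.Session() as sess:
  print(sess.run(grad_wrt_s))
```

With `use_reparametrization=True` (the default) the same call keeps the pre-existing behavior and requires `q` to be fully reparameterized.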
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index c120acaefc..fabf7a9b77 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -573,6 +574,88 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
rtol=0.05, atol=0.)
+ def test_score_trick(self):
+
+ with self.test_session() as sess:
+ d = 5 # Dimension
+ num_draws = int(1e5)
+ seed = 1
+
+ p = mvn_full_lib.MultivariateNormalFullCovariance(
+ covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5))
+
+ # Variance is very high when approximating Forward KL, so we make
+ # scale_diag larger than in test_kl_reverse_multidim. This ensures q
+ # "covers" p and thus Var_q[p/q] is smaller.
+ s = array_ops.constant(1.)
+ q = mvn_diag_lib.MultivariateNormalDiag(
+ scale_diag=array_ops.tile([s], [d]))
+
+ approx_kl = cd.monte_carlo_csiszar_f_divergence(
+ f=cd.kl_reverse,
+ p=p,
+ q=q,
+ num_draws=num_draws,
+ seed=seed)
+
+ approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence(
+ f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
+ p=p,
+ q=q,
+ num_draws=num_draws,
+ seed=seed)
+
+ approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence(
+ f=cd.kl_reverse,
+ p=p,
+ q=q,
+ num_draws=num_draws,
+ use_reparametrization=False,
+ seed=seed)
+
+ approx_kl_self_normalized_score_trick = (
+ cd.monte_carlo_csiszar_f_divergence(
+ f=lambda logu: cd.kl_reverse(logu, self_normalized=True),
+ p=p,
+ q=q,
+ num_draws=num_draws,
+ use_reparametrization=False,
+ seed=seed))
+
+ exact_kl = kullback_leibler.kl_divergence(q, p)
+
+ grad = lambda fs: gradients_impl.gradients(fs, s)[0]
+
+ [
+ approx_kl_,
+ approx_kl_self_normalized_,
+ approx_kl_score_trick_,
+ approx_kl_self_normalized_score_trick_,
+ exact_kl_,
+ ] = sess.run([
+ grad(approx_kl),
+ grad(approx_kl_self_normalized),
+ grad(approx_kl_score_trick),
+ grad(approx_kl_self_normalized_score_trick),
+ grad(exact_kl),
+ ])
+
+ self.assertAllClose(
+ approx_kl_, exact_kl_,
+ rtol=0.06, atol=0.)
+
+ self.assertAllClose(
+ approx_kl_self_normalized_, exact_kl_,
+ rtol=0.05, atol=0.)
+
+ self.assertAllClose(
+ approx_kl_score_trick_, exact_kl_,
+ rtol=0.06, atol=0.)
+
+ self.assertAllClose(
+ approx_kl_self_normalized_score_trick_, exact_kl_,
+ rtol=0.05, atol=0.)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 262da41bda..7b51d8d932 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -41,6 +41,7 @@ import numpy as np
from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
@@ -734,7 +735,8 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
+ dual_csiszar_function(logu, csiszar_function))
-def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
+def monte_carlo_csiszar_f_divergence(
+ f, p, q, num_draws, use_reparametrization=True, seed=None, name=None):
"""Monte-Carlo approximation of the Csiszar f-Divergence.
A Csiszar-function is a member of,
@@ -751,6 +753,38 @@ def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
where x_j ~iid q(x)
```
+ Tricks: Reparameterization and Score-Gradient
+
+ When q is "reparameterized", i.e., a diffeomorphic transformation of a
+ parameterless distribution (e.g.,
+ `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
+ expectation, i.e.,
+ `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }`, where
+ `s_i = f(x_i)` and `x_i ~ q`.
+
+ However, if q is not reparameterized, TensorFlow's gradient will be incorrect
+ since the chain rule stops at samples of non-reparameterized distributions. In
+ this circumstance, using the Score-Gradient trick results in an unbiased
+ gradient, i.e.,
+
+ ```none
+ nabla E_q[f(X)]
+ = nabla int dx q(x) f(x)
+ = int dx nabla [ q(x) f(x) ]
+ = int dx [ q'(x) f(x) + q(x) f'(x) ]
+ = int dx q(x) [ q'(x) / q(x) f(x) + f'(x) ]
+ = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ]
+ = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ]
+ ~= Avg{ nabla [ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) ] : y_i = stopgrad[x_i], x_i ~ q }
+ ```
+
+ When `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`, it is
+ usually preferable to set `use_reparametrization = True`.
+
+ Warning: using `use_reparametrization = False` means that the result is
+ *not* the Csiszar f-Divergence. However, its expected gradient *is* the
+ gradient of the Csiszar f-Divergence.
+
Example Application:
The Csiszar f-Divergence is a useful framework for variational inference.
@@ -775,29 +809,46 @@ def monte_carlo_csiszar_f_divergence(f, p, q, num_draws, seed=None, name=None):
`reparameterization_type`, `sample(n)`, and `log_prob(x)`.
num_draws: Integer scalar number of draws used to approximate the
f-Divergence expectation.
+ use_reparametrization: Python `bool`. When `True` uses the standard
+ Monte-Carlo average. When `False` uses the score-gradient trick. (See
+ above for details.)
seed: Python `int` seed for `q.sample`.
name: Python `str` name prefixed to Ops created by this function.
Returns:
monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo
- approximation of the Csiszar f-Divergence.
+ approximation of the Csiszar f-Divergence. Warning: using
+ `use_reparametrization = False` means that the result is *not* the
+ Csiszar f-Divergence. However, its expected gradient *is* the actual
+ gradient of the Csiszar f-Divergence.
Raises:
- ValueError: if `q` is not a reparameterized distribution. A distribution `q`
- is said to be "reparameterized" when its samples are generated by
- transforming the samples of another distribution which does not depend on
- the parameterization of `q`. This property ensures the gradient (with
- respect to parameters) is valid.
+ ValueError: if `q` is not a reparameterized distribution and
+ `use_reparametrization = True`. A distribution `q` is said to be
+ "reparameterized" when its samples are generated by transforming the
+ samples of another distribution which does not depend on the
+ parameterization of `q`. This property ensures the gradient (with respect
+ to parameters) is valid.
"""
with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
- # TODO(jvdillon): Consider only raising an exception if the gradient is
- # requested.
- if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
- raise ValueError(
- "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
- "transformation of a parameterless distribution. (Otherwise this "
- "function has a biased gradient.)")
x = q.sample(num_draws, seed=seed)
- return math_ops.reduce_mean(
- f(p.log_prob(x) - q.log_prob(x)),
- axis=0)
+ if use_reparametrization:
+ # TODO(jvdillon): Consider only raising an exception if the gradient is
+ # requested.
+ if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
+ raise ValueError(
+ "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
+ "transformation of a parameterless distribution. (Otherwise this "
+ "function has a biased gradient.)")
+ return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0)
+ else:
+ x = array_ops.stop_gradient(x)
+ logqx = q.log_prob(x)
+ fx = f(p.log_prob(x) - logqx)
+ # Alternatively we could have returned:
+ # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)
+ # This is nice because it means the result is exactly the Csiszar
+ # f-Divergence yet the gradient is unbiased. However, it is numerically
+ # unstable since q is not in the log-domain.
+ return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx,
+ axis=0)
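
For reference, the surrogate built in the `use_reparametrization=False` branch above can be written out by hand. The sketch below is hypothetical (scalar `Normal` distributions from `tf.contrib.distributions`, not part of the commit) and mirrors `reduce_mean(logqx * stop_gradient(fx) + fx, axis=0)`: its value is not the divergence, but its gradient is an unbiased estimate of the divergence's gradient, per the docstring derivation.

```python
# Hypothetical hand-rolled version of the score-gradient surrogate above
# (scalar Normals; not part of the commit).
import tensorflow as tf

s = tf.constant(1.5)                                # parameter of q
q = tf.contrib.distributions.Normal(loc=0., scale=s)
p = tf.contrib.distributions.Normal(loc=0., scale=1.)

x = tf.stop_gradient(q.sample(int(1e5), seed=0))    # treat draws as constants
logqx = q.log_prob(x)
fx = -(p.log_prob(x) - logqx)                       # kl_reverse: f(log u) = -log u

# The value of `surrogate` is not KL[q || p], but its gradient w.r.t. `s` is an
# unbiased estimate of d/ds KL[q || p] (see the docstring derivation).
surrogate = tf.reduce_mean(logqx * tf.stop_gradient(fx) + fx, axis=0)
grad_wrt_s = tf.gradients(surrogate, s)[0]
```

Multiplying by `exp(logqx) / stop_gradient(exp(logqx))` instead would keep the value equal to the divergence while leaving the gradient unbiased, which is the numerically unstable alternative the in-code comment above mentions.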