author    Joshua V. Dillon <jvdillon@google.com>  2017-06-22 16:50:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-06-22 16:55:03 -0700
commit b85601b95eba28605d3de076fa70cabf2f2e32b9 (patch)
tree   59feb5dc4af4a8bb481aef0d4e1d17632713f1bc
parent 07678fef5510d4a7c89d28b222ce72df49456a97 (diff)
Improve score-trick to be a valid Csiszar f-Divergence yet numerically stable.
PiperOrigin-RevId: 159896013
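The gist of the change: the old score-trick path returned `reduce_mean(logqx * stop_gradient(fx) + fx)` (see the deleted lines in `csiszar_divergence_impl.py` below), whose gradient is the correct score-function gradient but whose forward value is *not* the Csiszar f-Divergence. The natural alternative, `fx * exp(logqx) / stop_gradient(exp(logqx))`, is exact in both value and gradient but numerically unstable because `q` leaves the log domain. Working entirely in the log domain gives both properties at once; a minimal sketch of that idea (illustrative names, not the actual `monte_carlo.expectation_v2` source):

```python
# Sketch (assumed, not the actual expectation_v2 implementation) of the
# numerically stable score-trick estimator this commit adopts.
import tensorflow as tf

def score_trick_expectation(f, samples, log_prob):
  """Estimates E_q[f(X)] so the value is the plain sample mean while
  the gradient is the score-function (likelihood-ratio) gradient."""
  x = tf.stop_gradient(samples)   # block pathwise gradients through x
  logqx = log_prob(x)
  fx = f(x)
  # Forward pass: exp(logqx - stop_gradient(logqx)) == exp(0) == 1, so
  # the value is exactly reduce_mean(fx) -- a valid f-divergence
  # estimate. Backward pass: grad = grad[fx] + fx * grad[logqx], i.e.
  # the score-trick gradient. No raw exp(logqx) is ever materialized,
  # so the computation stays numerically stable.
  return tf.reduce_mean(fx * tf.exp(logqx - tf.stop_gradient(logqx)),
                        axis=0)
```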
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py  |  43
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py           |  62
2 files changed, 55 insertions(+), 50 deletions(-)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index fabf7a9b77..fba0cc6522 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -627,6 +627,11 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
grad = lambda fs: gradients_impl.gradients(fs, s)[0]
[
+ approx_kl_grad_,
+ approx_kl_self_normalized_grad_,
+ approx_kl_score_trick_grad_,
+ approx_kl_self_normalized_score_trick_grad_,
+ exact_kl_grad_,
approx_kl_,
approx_kl_self_normalized_,
approx_kl_score_trick_,
@@ -638,23 +643,39 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
grad(approx_kl_score_trick),
grad(approx_kl_self_normalized_score_trick),
grad(exact_kl),
+ approx_kl,
+ approx_kl_self_normalized,
+ approx_kl_score_trick,
+ approx_kl_self_normalized_score_trick,
+ exact_kl,
])
- self.assertAllClose(
- approx_kl_, exact_kl_,
- rtol=0.06, atol=0.)
+ # Test average divergence.
+ self.assertAllClose(approx_kl_, exact_kl_,
+ rtol=0.02, atol=0.)
- self.assertAllClose(
- approx_kl_self_normalized_, exact_kl_,
- rtol=0.05, atol=0.)
+ self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
+ rtol=0.08, atol=0.)
- self.assertAllClose(
- approx_kl_score_trick_, exact_kl_,
- rtol=0.06, atol=0.)
+ self.assertAllClose(approx_kl_score_trick_, exact_kl_,
+ rtol=0.02, atol=0.)
+
+ self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
+ rtol=0.08, atol=0.)
+
+ # Test average gradient-divergence.
+ self.assertAllClose(approx_kl_grad_, exact_kl_grad_,
+ rtol=0.007, atol=0.)
+
+ self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_,
+ rtol=0.011, atol=0.)
+
+ self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_,
+ rtol=0.018, atol=0.)
self.assertAllClose(
- approx_kl_self_normalized_score_trick_, exact_kl_,
- rtol=0.05, atol=0.)
+ approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_,
+ rtol=0.017, atol=0.)
if __name__ == '__main__':
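For orientation, the updated test's pattern, condensed into a hypothetical single-estimator analogue (TF1-era API; the real test above pushes four Csiszar estimators through the same machinery): evaluate each estimator and its gradient with respect to the scale parameter in one `sess.run`, then compare both against the exact KL under estimator-specific tolerances.

```python
# Hypothetical, simplified analogue of the value-and-gradient checks
# added above (TF1-era API; the real test covers four estimators).
import tensorflow as tf
tfd = tf.contrib.distributions

s = tf.Variable(1.)
q = tfd.Normal(loc=0., scale=tf.nn.softplus(s))
p = tfd.Normal(loc=0., scale=2.)

exact_kl = tfd.kl_divergence(q, p)
x = q.sample(int(1e5), seed=1)
approx_kl = tf.reduce_mean(q.log_prob(x) - p.log_prob(x), axis=0)

grad = lambda fs: tf.gradients(fs, s)[0]
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  approx_kl_grad_, exact_kl_grad_, approx_kl_, exact_kl_ = sess.run(
      [grad(approx_kl), grad(exact_kl), approx_kl, exact_kl])
  # Divergences and gradients should each agree within a small rtol.
```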
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 7b51d8d932..09389e5d38 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -40,8 +40,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
@@ -750,7 +750,7 @@ def monte_carlo_csiszar_f_divergence(
```none
D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
- where x_j ~iid q(x)
+ where x_j ~iid q(X)
```
Tricks: Reparameterization and Score-Gradient
@@ -759,8 +759,8 @@ def monte_carlo_csiszar_f_divergence(
parameterless distribution (e.g.,
`Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
expectation, i.e.,
- `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }` where `S_n=Avg{s_i}`
- and `s_i = f(x_i), x_i ~ q`.
+ `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}`
+ and `s_i = f(x_i), x_i ~iid q(X)`.
However, if q is not reparameterized, TensorFlow's gradient will be incorrect
since the chain-rule stops at samples of unreparameterized distributions. In
@@ -768,22 +768,17 @@ def monte_carlo_csiszar_f_divergence(
gradient, i.e.,
```none
- nabla E_q[f(X)]
- = nabla int dx q(x) f(x)
- = int dx nabla [ q(x) f(x) ]
- = int dx q'(x) f(x) + q(x) f'(x)
+ grad[ E_q[f(X)] ]
+ = grad[ int dx q(x) f(x) ]
+ = int dx grad[ q(x) f(x) ]
+ = int dx [ q'(x) f(x) + q(x) f'(x) ]
= int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
- = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ]
- = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ]
- ~= Avg{ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) : y_i = stopgrad[x_i], x_i ~ q}
+ = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
+ = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
```
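In code, the last two lines of this derivation say the old estimator (deleted further down in this file's diff) and the new log-domain form differ only in forward value, never in gradient. A hypothetical check, taking `f = kl_reverse` so `f(logu) = -logu`:

```python
# Hypothetical check (illustrative names): old and new score-trick
# estimators have identical gradients; only the new one has the
# f-divergence as its forward value.
import tensorflow as tf

s = tf.Variable(0.5)
q = tf.contrib.distributions.Normal(loc=0., scale=tf.nn.softplus(s))
p = tf.contrib.distributions.Normal(loc=0., scale=1.)

x = tf.stop_gradient(q.sample(int(1e4), seed=0))
logqx = q.log_prob(x)
fx = -(p.log_prob(x) - logqx)   # f = kl_reverse: f(logu) = -logu

old = tf.reduce_mean(logqx * tf.stop_gradient(fx) + fx, axis=0)
new = tf.reduce_mean(fx * tf.exp(logqx - tf.stop_gradient(logqx)),
                     axis=0)

g_old = tf.gradients(old, s)[0]
g_new = tf.gradients(new, s)[0]
# Expect: new == reduce_mean(fx) exactly; g_old == g_new exactly, since
# grad[logqx * stop_grad[fx] + fx] == grad[fx] + fx * grad[logqx].
```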
Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is
- usually preferable to `use_reparametrization = True`.
-
- Warning: using `use_reparametrization = False` will mean that the result is
- *not* the Csiszar f-Divergence. However its expected gradient *is* the
- gradient of the Csiszar f-Divergence.
+ usually preferable to set `use_reparametrization = True`.
Example Application:
@@ -817,10 +812,7 @@ def monte_carlo_csiszar_f_divergence(
Returns:
monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo
- approximation of the Csiszar f-Divergence. Warning: using
- `use_reparametrization = False` will mean that the result is *not* the
- Csiszar f-Divergence. However its expected gradient *is* the actual
- gradient of the Csiszar f-Divergence.
+ approximation of the Csiszar f-Divergence.
Raises:
ValueError: if `q` is not a reparameterized distribution and
@@ -831,24 +823,16 @@ def monte_carlo_csiszar_f_divergence(
to parameters) is valid.
"""
with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
- x = q.sample(num_draws, seed=seed)
- if use_reparametrization:
+ if (use_reparametrization and
+ q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
# TODO(jvdillon): Consider only raising an exception if the gradient is
# requested.
- if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
- raise ValueError(
- "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
- "transformation of a parameterless distribution. (Otherwise this "
- "function has a biased gradient.)")
- return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0)
- else:
- x = array_ops.stop_gradient(x)
- logqx = q.log_prob(x)
- fx = f(p.log_prob(x) - logqx)
- # Alternatively we could have returned:
- # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)
- # This is nice because it means the result is exactly the Csiszar
- # f-Divergence yet the gradient is unbiased. However its numerically
- # unstable since the q is not in log-domain.
- return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx,
- axis=0)
+ raise ValueError(
+ "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
+ "transformation of a parameterless distribution. (Otherwise this "
+ "function has a biased gradient.)")
+ return monte_carlo.expectation_v2(
+ f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
+ samples=q.sample(num_draws, seed=seed),
+ log_prob=q.log_prob,
+ use_reparametrization=use_reparametrization)
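Finally, a hypothetical call site for the function after this change (import path and the `kl_reverse` helper are assumed from the surrounding `bayesflow` module; signature as shown in the diff above):

```python
# Hypothetical usage of monte_carlo_csiszar_f_divergence after this
# change (paths/names assumed from tensorflow.contrib.bayesflow).
import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl as cd
tfd = tf.contrib.distributions

q = tfd.Normal(loc=0., scale=1.)    # approximating distribution
p = tfd.Normal(loc=0.5, scale=2.)   # target distribution

# q is fully reparameterized, so the reparameterization trick applies;
# with use_reparametrization=False the same call would route through
# the (now numerically stable) score trick instead.
kl_approx = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse,
    p=p,
    q=q,
    num_draws=int(1e4),
    use_reparametrization=True)
```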