author | Joshua V. Dillon <jvdillon@google.com> | 2017-06-22 16:50:14 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-22 16:55:03 -0700
commit | b85601b95eba28605d3de076fa70cabf2f2e32b9 (patch)
tree | 59feb5dc4af4a8bb481aef0d4e1d17632713f1bc
parent | 07678fef5510d4a7c89d28b222ce72df49456a97 (diff)
Improve score-trick to be a valid Csiszar f-Divergence yet numerically stable.
PiperOrigin-RevId: 159896013
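In effect, the change swaps the additive score-trick surrogate (whose Monte Carlo value is not itself the f-Divergence) for a multiplicative surrogate whose value equals the plain Monte Carlo f-Divergence while keeping `q` in the log domain. Below is a minimal sketch of the two surrogates, assuming the TF 1.x API of the time; the helper name and setup are illustrative, not code from this commit.

```python
import tensorflow as tf  # TF 1.x-era API assumed


def score_trick_surrogates(f, p, q, num_draws, seed=None):
  """Two score-trick surrogates for D_f[p, q] = E_q[f(log(p/q))].

  Illustrative helper (not from this commit): `f` maps log(p(x)/q(x)) to the
  Csiszar function value; `p` and `q` expose `sample` and `log_prob`.
  """
  x = tf.stop_gradient(q.sample(num_draws, seed=seed))
  logqx = q.log_prob(x)
  fx = f(p.log_prob(x) - logqx)
  # Old surrogate: its gradient is an unbiased estimate of grad D_f, but its
  # *value* is mean(logqx * fx + fx), not the f-Divergence itself.
  old = tf.reduce_mean(logqx * tf.stop_gradient(fx) + fx, axis=0)
  # New surrogate: exp(logqx - stop_gradient(logqx)) == 1 in value, so the
  # estimate is exactly mean(fx), i.e. a valid Monte Carlo f-Divergence, yet
  # the factor still carries d(logqx)/d(theta) and so yields the same
  # score-function gradient. Staying in the log domain avoids exp(logqx)
  # overflow/underflow.
  new = tf.reduce_mean(fx * tf.exp(logqx - tf.stop_gradient(logqx)), axis=0)
  return old, new
```

The new form is the log-domain rewrite of the alternative noted in the removed comment, `reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)`, which was exact in value but numerically unstable because `q` was not kept in the log domain.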
-rw-r--r-- | tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py | 43
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py | 62
2 files changed, 55 insertions, 50 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
index fabf7a9b77..fba0cc6522 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py
@@ -627,6 +627,11 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
       grad = lambda fs: gradients_impl.gradients(fs, s)[0]
 
       [
+          approx_kl_grad_,
+          approx_kl_self_normalized_grad_,
+          approx_kl_score_trick_grad_,
+          approx_kl_self_normalized_score_trick_grad_,
+          exact_kl_grad_,
           approx_kl_,
           approx_kl_self_normalized_,
           approx_kl_score_trick_,
@@ -638,23 +643,39 @@ class MonteCarloCsiszarFDivergenceTest(test.TestCase):
           grad(approx_kl_score_trick),
           grad(approx_kl_self_normalized_score_trick),
           grad(exact_kl),
+          approx_kl,
+          approx_kl_self_normalized,
+          approx_kl_score_trick,
+          approx_kl_self_normalized_score_trick,
+          exact_kl,
       ])
 
-      self.assertAllClose(
-          approx_kl_, exact_kl_,
-          rtol=0.06, atol=0.)
+      # Test average divergence.
+      self.assertAllClose(approx_kl_, exact_kl_,
+                          rtol=0.02, atol=0.)
 
-      self.assertAllClose(
-          approx_kl_self_normalized_, exact_kl_,
-          rtol=0.05, atol=0.)
+      self.assertAllClose(approx_kl_self_normalized_, exact_kl_,
+                          rtol=0.08, atol=0.)
 
-      self.assertAllClose(
-          approx_kl_score_trick_, exact_kl_,
-          rtol=0.06, atol=0.)
+      self.assertAllClose(approx_kl_score_trick_, exact_kl_,
+                          rtol=0.02, atol=0.)
+
+      self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_,
+                          rtol=0.08, atol=0.)
+
+      # Test average gradient-divergence.
+      self.assertAllClose(approx_kl_grad_, exact_kl_grad_,
+                          rtol=0.007, atol=0.)
+
+      self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_,
+                          rtol=0.011, atol=0.)
+
+      self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_,
+                          rtol=0.018, atol=0.)
 
       self.assertAllClose(
-          approx_kl_self_normalized_score_trick_, exact_kl_,
-          rtol=0.05, atol=0.)
+          approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_,
+          rtol=0.017, atol=0.)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 7b51d8d932..09389e5d38 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -40,8 +40,8 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops.distributions import distribution
@@ -750,7 +750,7 @@ def monte_carlo_csiszar_f_divergence(
   ```none
   D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
                   ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
-                       where x_j ~iid q(x)
+                       where x_j ~iid q(X)
   ```
 
   Tricks: Reparameterization and Score-Gradient
@@ -759,8 +759,8 @@ def monte_carlo_csiszar_f_divergence(
   parameterless distribution (e.g.,
   `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
   expectation, i.e.,
-  `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }` where `S_n=Avg{s_i}`
-  and `s_i = f(x_i), x_i ~ q`.
+  `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}`
+  and `s_i = f(x_i), x_i ~iid q(X)`.
 
   However, if q is not reparameterized, TensorFlow's gradient will be incorrect
   since the chain-rule stops at samples of unreparameterized distributions. In
@@ -768,22 +768,17 @@ def monte_carlo_csiszar_f_divergence(
   gradient, i.e.,
 
   ```none
-  nabla E_q[f(X)]
-  = nabla int dx q(x) f(x)
-  = int dx nabla [ q(x) f(x) ]
-  = int dx q'(x) f(x) + q(x) f'(x)
+  grad[ E_q[f(X)] ]
+  = grad[ int dx q(x) f(x) ]
+  = int dx grad[ q(x) f(x) ]
+  = int dx [ q'(x) f(x) + q(x) f'(x) ]
   = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
-  = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ]
-  = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ]
-  ~= Avg{ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) : y_i = stopgrad[x_i], x_i ~ q}
+  = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
+  = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
   ```
 
   Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is
-  usually preferable to `use_reparametrization = True`.
-
-  Warning: using `use_reparametrization = False` will mean that the result is
-  *not* the Csiszar f-Divergence. However its expected gradient *is* the
-  gradient of the Csiszar f-Divergence.
+  usually preferable to set `use_reparametrization = True`.
 
   Example Application:
 
@@ -817,10 +812,7 @@ def monte_carlo_csiszar_f_divergence(
 
   Returns:
     monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo
-      approximation of the Csiszar f-Divergence. Warning: using
-      `use_reparametrization = False` will mean that the result is *not* the
-      Csiszar f-Divergence. However its expected gradient *is* the actual
-      gradient of the Csiszar f-Divergence.
+      approximation of the Csiszar f-Divergence.
 
   Raises:
     ValueError: if `q` is not a reparameterized distribution and
@@ -831,24 +823,16 @@ def monte_carlo_csiszar_f_divergence(
       to parameters) is valid.
   """
   with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
-    x = q.sample(num_draws, seed=seed)
-    if use_reparametrization:
+    if (use_reparametrization and
+        q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
       # TODO(jvdillon): Consider only raising an exception if the gradient is
       # requested.
-      if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
-        raise ValueError(
-            "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
-            "transformation of a parameterless distribution. (Otherwise this "
-            "function has a biased gradient.)")
-      return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0)
-    else:
-      x = array_ops.stop_gradient(x)
-      logqx = q.log_prob(x)
-      fx = f(p.log_prob(x) - logqx)
-      # Alternatively we could have returned:
-      # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)
-      # This is nice because it means the result is exactly the Csiszar
-      # f-Divergence yet the gradient is unbiased. However its numerically
-      # unstable since the q is not in log-domain.
-      return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx,
-                                  axis=0)
+      raise ValueError(
+          "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
+          "transformation of a parameterless distribution. (Otherwise this "
+          "function has a biased gradient.)")
+    return monte_carlo.expectation_v2(
+        f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
+        samples=q.sample(num_draws, seed=seed),
+        log_prob=q.log_prob,
+        use_reparametrization=use_reparametrization)
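In the same spirit as the updated test, here is a small self-contained check that the stable score-trick estimate tracks a closed-form KL in both value and gradient. The distributions, sample size, and seed are illustrative assumptions (using the TF 1.x `tf.contrib.distributions` API of the time), not the test's actual setup.

```python
import tensorflow as tf  # TF 1.x-era API assumed

tfd = tf.contrib.distributions

s = tf.constant(1.5)                # parameter we differentiate with respect to
q = tfd.Normal(loc=0., scale=s)     # sampling distribution q(X), depends on s
p = tfd.Normal(loc=0.2, scale=1.)   # target distribution p(X)
f = lambda logu: -logu              # reverse KL: E_q[-log(p/q)] = KL[q || p]

x = tf.stop_gradient(q.sample(int(1e5), seed=42))
logqx = q.log_prob(x)
fx = f(p.log_prob(x) - logqx)
# Numerically stable score-trick estimate: equal to reduce_mean(fx) in value,
# while exp(logqx - stop_gradient(logqx)) still carries the score gradient.
approx_kl = tf.reduce_mean(fx * tf.exp(logqx - tf.stop_gradient(logqx)))

# Closed-form KL[q || p] between the two Normals, written out directly.
mq, sq, mp, sp = 0., s, 0.2, 1.
exact_kl = tf.log(sp / sq) + (sq**2 + (mq - mp)**2) / (2. * sp**2) - 0.5

grad = lambda t: tf.gradients(t, s)[0]
with tf.Session() as sess:
  a, e, ga, ge = sess.run(
      [approx_kl, exact_kl, grad(approx_kl), grad(exact_kl)])
  # Expect a ~= e and ga ~= ge, up to Monte Carlo error.
```

Up to Monte Carlo error, `a` should track `e` and `ga` should track `ge`, which is the same pairing of value and gradient checks that the new `rtol` assertions in the test apply to the module's estimators.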