diff options
Diffstat (limited to 'tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py')
-rw-r--r-- | tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py | 62 |
1 files changed, 39 insertions, 23 deletions
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py index 09389e5d38..7b51d8d932 100644 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py @@ -40,8 +40,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution @@ -750,7 +750,7 @@ def monte_carlo_csiszar_f_divergence( ```none D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), - where x_j ~iid q(X) + where x_j ~iid q(x) ``` Tricks: Reparameterization and Score-Gradient @@ -759,8 +759,8 @@ def monte_carlo_csiszar_f_divergence( parameterless distribution (e.g., `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and expectation, i.e., - `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}` - and `s_i = f(x_i), x_i ~iid q(X)`. + `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }` where `S_n=Avg{s_i}` + and `s_i = f(x_i), x_i ~ q`. However, if q is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of unreparameterized distributions. In @@ -768,17 +768,22 @@ def monte_carlo_csiszar_f_divergence( gradient, i.e., ```none - grad[ E_q[f(X)] ] - = grad[ int dx q(x) f(x) ] - = int dx grad[ q(x) f(x) ] - = int dx [ q'(x) f(x) + q(x) f'(x) ] + nabla E_q[f(X)] + = nabla int dx q(x) f(x) + = int dx nabla [ q(x) f(x) ] + = int dx q'(x) f(x) + q(x) f'(x) = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] - = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ] - = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ] + = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ] + = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ] + ~= Avg{ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) : y_i = stopgrad[x_i], x_i ~ q} ``` Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is - usually preferable to set `use_reparametrization = True`. + usually preferable to `use_reparametrization = True`. + + Warning: using `use_reparametrization = False` will mean that the result is + *not* the Csiszar f-Divergence. However its expected gradient *is* the + gradient of the Csiszar f-Divergence. Example Application: @@ -812,7 +817,10 @@ def monte_carlo_csiszar_f_divergence( Returns: monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo - approximation of the Csiszar f-Divergence. + approximation of the Csiszar f-Divergence. Warning: using + `use_reparametrization = False` will mean that the result is *not* the + Csiszar f-Divergence. However its expected gradient *is* the actual + gradient of the Csiszar f-Divergence. Raises: ValueError: if `q` is not a reparameterized distribution and @@ -823,16 +831,24 @@ def monte_carlo_csiszar_f_divergence( to parameters) is valid. """ with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): - if (use_reparametrization and - q.reparameterization_type != distribution.FULLY_REPARAMETERIZED): + x = q.sample(num_draws, seed=seed) + if use_reparametrization: # TODO(jvdillon): Consider only raising an exception if the gradient is # requested. - raise ValueError( - "Distribution `q` must be reparameterized, i.e., a diffeomorphic " - "transformation of a parameterless distribution. (Otherwise this " - "function has a biased gradient.)") - return monte_carlo.expectation_v2( - f=lambda x: f(p.log_prob(x) - q.log_prob(x)), - samples=q.sample(num_draws, seed=seed), - log_prob=q.log_prob, - use_reparametrization=use_reparametrization) + if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED: + raise ValueError( + "Distribution `q` must be reparameterized, i.e., a diffeomorphic " + "transformation of a parameterless distribution. (Otherwise this " + "function has a biased gradient.)") + return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0) + else: + x = array_ops.stop_gradient(x) + logqx = q.log_prob(x) + fx = f(p.log_prob(x) - logqx) + # Alternatively we could have returned: + # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0) + # This is nice because it means the result is exactly the Csiszar + # f-Divergence yet the gradient is unbiased. However its numerically + # unstable since the q is not in log-domain. + return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx, + axis=0) |