Diffstat (limited to 'tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py')
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py  62
1 file changed, 39 insertions(+), 23 deletions(-)
diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
index 09389e5d38..7b51d8d932 100644
--- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py
@@ -40,8 +40,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
@@ -750,7 +750,7 @@ def monte_carlo_csiszar_f_divergence(
```none
D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ]
~= m**-1 sum_j^m f( p(x_j) / q(x_j) ),
- where x_j ~iid q(X)
+ where x_j ~iid q
```
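To make the estimator above concrete, here is a minimal NumPy/SciPy sketch; the helper name and the stand-in distributions are illustrative, not part of this change. Note that, matching the implementation further down, `f` is applied to the *log* ratio `log[p(x_j) / q(x_j)]`, which is how the csiszar functions in this module are parameterized:

```python
import numpy as np
from scipy import stats

def mc_f_divergence(f, p_logpdf, q, m, seed=0):
  # D_f[p, q] ~= m**-1 sum_j f(log[p(x_j) / q(x_j)]), x_j ~iid q.
  rng = np.random.RandomState(seed)
  x = q.rvs(size=m, random_state=rng)
  return np.mean(f(p_logpdf(x) - q.logpdf(x)))

# f(logu) = -logu is reverse KL, i.e. KL[q, p] (cf. `kl_reverse` in this file).
p, q = stats.norm(0., 1.), stats.norm(0.5, 0.8)
print(mc_f_divergence(lambda logu: -logu, p.logpdf, q, m=100000))
```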
Tricks: Reparameterization and Score-Gradient
@@ -759,8 +759,8 @@ def monte_carlo_csiszar_f_divergence(
parameterless distribution (e.g.,
`Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
expectation, i.e.,
- `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}`
- and `s_i = f(x_i), x_i ~iid q(X)`.
+ `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }` where
+ `s_i = f(x_i), x_i ~ q`.
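As an aside, a minimal TF 1.x sketch of why reparameterization lets the gradient pass through the average; the variable names and the toy `f` are illustrative:

```python
import tensorflow as tf

m = tf.Variable(0.5, name="loc")
s = tf.Variable(1.2, name="scale")

# Draws from Normal(m, s) written as a deterministic transform of a
# parameterless Normal(0, 1), so gradients flow from x back to (m, s).
eps = tf.random_normal([1000])
x = m + s * eps

avg = tf.reduce_mean(tf.square(x))          # Avg{ s_i : i=1...n }, s_i = f(x_i)
grad_m, grad_s = tf.gradients(avg, [m, s])  # nabla Avg == Avg nabla
```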
However, if q is not reparameterized, TensorFlow's gradient will be incorrect
since the chain-rule stops at samples of unreparameterized distributions. In
@@ -768,17 +768,22 @@ def monte_carlo_csiszar_f_divergence(
gradient, i.e.,
```none
- grad[ E_q[f(X)] ]
- = grad[ int dx q(x) f(x) ]
- = int dx grad[ q(x) f(x) ]
- = int dx [ q'(x) f(x) + q(x) f'(x) ]
+ nabla E_q[f(X)]
+ = nabla int dx q(x) f(x)
+ = int dx nabla [ q(x) f(x) ]
+ = int dx [ q'(x) f(x) + q(x) f'(x) ]
= int dx q(x) [q'(x) / q(x) f(x) + f'(x) ]
- = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ]
- = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ]
+ = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ]
+ = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ]
+ ~= Avg{ nabla[ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) ] : y_i = stopgrad[x_i], x_i ~ q }
```
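The last line is exactly the surrogate the non-reparameterized branch below builds with `stop_gradient`: its *value* is not the expectation being estimated, but its *gradient* is an unbiased estimate of the expectation's gradient. A standalone TF 1.x sketch (the helper name is illustrative):

```python
import tensorflow as tf

def score_gradient_surrogate(q, f, num_draws):
  # y_i = stopgrad[x_i], x_i ~ q: no gradient through the sampling path.
  y = tf.stop_gradient(q.sample(num_draws))
  logqy = q.log_prob(y)  # log(q(y_i)); still depends on q's parameters.
  fy = f(y)
  # Value: Avg{ log(q(y_i)) f(y_i) + f(y_i) } -- *not* E_q[f(X)].
  # Gradient: Avg{ nabla log(q(y_i)) stopgrad[f(y_i)] + nabla f(y_i) },
  # an unbiased estimate of nabla E_q[f(X)].
  return tf.reduce_mean(logqy * tf.stop_gradient(fy) + fy, axis=0)
```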
Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is
usually preferable to set `use_reparametrization = True`.
+
+ Warning: using `use_reparametrization = False` means the result is *not* the
+ Csiszar f-Divergence. However, its expected gradient *is* the gradient of
+ the Csiszar f-Divergence.
Example Application:
@@ -812,7 +817,10 @@ def monte_carlo_csiszar_f_divergence(
Returns:
monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo
- approximation of the Csiszar f-Divergence.
+ approximation of the Csiszar f-Divergence. Warning: using
+ `use_reparametrization = False` means the result is *not* the Csiszar
+ f-Divergence. However, its expected gradient *is* the actual gradient of
+ the Csiszar f-Divergence.
Raises:
ValueError: if `q` is not a reparameterized distribution and
@@ -823,16 +831,24 @@ def monte_carlo_csiszar_f_divergence(
to parameters) is valid.
"""
with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]):
- if (use_reparametrization and
- q.reparameterization_type != distribution.FULLY_REPARAMETERIZED):
+ x = q.sample(num_draws, seed=seed)
+ if use_reparametrization:
# TODO(jvdillon): Consider only raising an exception if the gradient is
# requested.
- raise ValueError(
- "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
- "transformation of a parameterless distribution. (Otherwise this "
- "function has a biased gradient.)")
- return monte_carlo.expectation_v2(
- f=lambda x: f(p.log_prob(x) - q.log_prob(x)),
- samples=q.sample(num_draws, seed=seed),
- log_prob=q.log_prob,
- use_reparametrization=use_reparametrization)
+ if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED:
+ raise ValueError(
+ "Distribution `q` must be reparameterized, i.e., a diffeomorphic "
+ "transformation of a parameterless distribution. (Otherwise this "
+ "function has a biased gradient.)")
+ return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0)
+ else:
+ x = array_ops.stop_gradient(x)
+ logqx = q.log_prob(x)
+ fx = f(p.log_prob(x) - logqx)
+ # Alternatively we could have returned:
+ # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0)
+ # This is nice because it means the result is exactly the Csiszar
+ # f-Divergence yet the gradient is unbiased. However, it is numerically
+ # unstable since q is not in the log-domain.
+ return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx,
+ axis=0)
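For reference, a sketch of how the revised function might be exercised after this change, assuming the contrib layout at this commit; the specific distributions and draw count are illustrative:

```python
import tensorflow as tf
from tensorflow.contrib import distributions as tfd
from tensorflow.contrib.bayesflow.python.ops import (
    csiszar_divergence_impl as cd)

p = tfd.Normal(loc=0., scale=1.)
q = tfd.Normal(loc=tf.Variable(0.5), scale=1.)  # FULLY_REPARAMETERIZED

# f = kl_reverse, i.e. f(logu) = -logu, estimates KL[q, p].
divergence = cd.monte_carlo_csiszar_f_divergence(
    f=cd.kl_reverse, p=p, q=q, num_draws=1000, use_reparametrization=True)

grads = tf.gradients(divergence, tf.trainable_variables())
```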