author     Olivia Nordquist <nolivia@google.com>            2017-02-16 15:50:55 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-02-16 16:03:41 -0800
commit     7a1e163cbb41d2e682cfe6b6941a50be72f11b96 (patch)
tree       e45eeaf2b4aea660845041f2f1ffbfb921cbf3f0
parent     393d3d92d71f195ecbd60ea5bb0885071c6a20c9 (diff)
Task 6: sealing up bayesflow.{entropy, monte_carlo, variational_inference}
Change: 147778589
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py                    2
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py               11
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py       4
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py      4
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/entropy.py                                 372
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/entropy_impl.py                            384
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/monte_carlo.py                             252
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py                        274
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/variational_inference.py                   311
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py              327
-rw-r--r--  tensorflow/tools/docs/generate.py                                                   11
11 files changed, 1025 insertions, 927 deletions
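
The pattern applied by this change is the same for all three modules: the full implementation moves into a new `*_impl.py` file, and the public module becomes a thin "sealed" shim that wildcard-imports the implementation and then calls `remove_undocumented` to strip every attribute not on an explicit allow-list. The sketch below is condensed from the entropy case shown in this diff (the monte_carlo and variational_inference shims follow the same shape); private helpers such as `_get_samples` remain reachable only through the `*_impl` modules.

```python
# Sealed public module (entropy.py), condensed from the diff below.
# pylint: disable=wildcard-import
from tensorflow.contrib.bayesflow.python.ops.entropy_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    'ELBOForms', 'elbo_ratio', 'entropy_shannon', 'renyi_ratio', 'renyi_alpha'
]

# Delete every module attribute that is not allow-listed above.
remove_undocumented(__name__, _allowed_symbols)
```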
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
index d219cd569c..57a38bd5f9 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py
@@ -29,7 +29,7 @@ import numpy as np
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers as layers_lib
-from tensorflow.contrib.bayesflow.python.ops import entropy as entropy_lib
+from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index db8275ce0c..12c05e34e4 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -27,7 +27,8 @@ if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'):
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers as layers_lib
-from tensorflow.contrib.bayesflow.python.ops import monte_carlo as monte_carlo_lib
+from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
+from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
@@ -156,7 +157,7 @@ class GetSamplesTest(test.TestCase):
n = None
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
- monte_carlo._get_samples(dist, z, n, seed)
+ _get_samples(dist, z, n, seed)
def test_raises_if_both_z_and_n_are_not_none(self):
with self.test_session():
@@ -165,7 +166,7 @@ class GetSamplesTest(test.TestCase):
n = 1
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
- monte_carlo._get_samples(dist, z, n, seed)
+ _get_samples(dist, z, n, seed)
def test_returns_n_samples_if_n_provided(self):
with self.test_session():
@@ -173,7 +174,7 @@ class GetSamplesTest(test.TestCase):
z = None
n = 10
seed = None
- z = monte_carlo._get_samples(dist, z, n, seed)
+ z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())
def test_returns_z_if_z_provided(self):
@@ -182,7 +183,7 @@ class GetSamplesTest(test.TestCase):
z = dist.sample(10, seed=42)
n = None
seed = None
- z = monte_carlo._get_samples(dist, z, n, seed)
+ z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
index 7bdd0a3269..9ee59a03ca 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py
@@ -22,7 +22,7 @@ import numpy as np
from tensorflow.contrib import distributions
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor
from tensorflow.contrib.bayesflow.python.ops import stochastic_variables
-from tensorflow.contrib.bayesflow.python.ops import variational_inference
+from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -33,7 +33,7 @@ from tensorflow.python.platform import test
sv = stochastic_variables
st = stochastic_tensor
-vi = variational_inference
+vi = variational_inference_impl
dist = distributions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
index 49ece025f2..12eb66b65d 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py
@@ -28,7 +28,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor
-from tensorflow.contrib.bayesflow.python.ops import variational_inference
+from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.python.framework import constant_op
@@ -38,7 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
st = stochastic_tensor
-vi = variational_inference
+vi = variational_inference_impl
distributions = distributions_lib
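
All of the kernel-test churn above is a direct consequence of the seal: tests that previously reached private helpers or aliased the public modules now import the `*_impl` modules instead. A minimal sketch of the new import path for `_get_samples`, mirroring `GetSamplesTest`; the `Normal(loc=..., scale=...)` constructor arguments are assumed from the TF 1.x contrib API and are not part of this diff.

```python
from tensorflow.contrib import distributions as distributions_lib
# Private helpers now live on the implementation module, not the sealed public one.
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples

dist = distributions_lib.Normal(loc=0.0, scale=1.0)
# Exactly one of z (pre-drawn samples) or n (a sample count) may be supplied.
z = _get_samples(dist, z=None, n=10, seed=42)
print(z.get_shape())  # (10,)
```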
diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy.py b/tensorflow/contrib/bayesflow/python/ops/entropy.py
index 4df6d5a911..a22e1c1d4e 100644
--- a/tensorflow/contrib/bayesflow/python/ops/entropy.py
+++ b/tensorflow/contrib/bayesflow/python/ops/entropy.py
@@ -1,4 +1,4 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,372 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}.
-
-@@elbo_ratio
-@@entropy_shannon
-@@renyi_ratio
-@@renyi_alpha
-"""
+"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
-from tensorflow.contrib.bayesflow.python.ops import monte_carlo
-from tensorflow.contrib.bayesflow.python.ops import variational_inference
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import tf_logging as logging
-
-# Make utility functions from monte_carlo available.
-# pylint: disable=protected-access
-_get_samples = monte_carlo._get_samples
-_logspace_mean = monte_carlo._logspace_mean
-_sample_mean = monte_carlo._sample_mean
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.entropy_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=protected-access
-
-__all__ = [
- 'elbo_ratio',
- 'entropy_shannon',
- 'renyi_ratio',
- 'renyi_alpha',
+_allowed_symbols = [
+ 'ELBOForms', 'elbo_ratio', 'entropy_shannon', 'renyi_ratio', 'renyi_alpha'
]
-
-ELBOForms = variational_inference.ELBOForms # pylint: disable=invalid-name
-
-
-def elbo_ratio(log_p,
- q,
- z=None,
- n=None,
- seed=None,
- form=None,
- name='elbo_ratio'):
- r"""Estimate of the ratio appearing in the `ELBO` and `KL` divergence.
-
- With `p(z) := exp{log_p(z)}`, this `Op` returns an approximation of
-
- ```
- E_q[ Log[p(Z) / q(Z)] ]
- ```
-
- The term `E_q[ Log[p(Z)] ]` is always computed as a sample mean.
- The term `E_q[ Log[q(z)] ]` can be computed with samples, or an exact formula
- if `q.entropy()` is defined. This is controlled with the kwarg `form`.
-
- This log-ratio appears in different contexts:
-
- #### `KL[q || p]`
-
- If `log_p(z) = Log[p(z)]` for distribution `p`, this `Op` approximates
- the negative Kullback-Leibler divergence.
-
- ```
- elbo_ratio(log_p, q, n=100) = -1 * KL[q || p],
- KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ]
- ```
-
- Note that if `p` is a `Distribution`, then `distributions.kl(q, p)` may be
- defined and available as an exact result.
-
- #### ELBO
-
- If `log_p(z) = Log[p(z, x)]` is the log joint of a distribution `p`, this is
- the Evidence Lower BOund (ELBO):
-
- ```
- ELBO ~= E[ Log[p(Z, x)] - Log[q(Z)] ]
- = Log[p(x)] - KL[q || p]
- <= Log[p(x)]
- ```
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- log_p: Callable mapping samples from `q` to `Tensors` with
- shape broadcastable to `q.batch_shape`.
- For example, `log_p` works "just like" `q.log_prob`.
- q: `tf.contrib.distributions.Distribution`.
- z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
- or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default`
- (attempt analytic entropy, fallback on sample).
- Default value is `ELBOForms.default`.
- name: A name to give this `Op`.
-
- Returns:
- Scalar `Tensor` holding sample mean KL divergence. `shape` is the batch
- shape of `q`, and `dtype` is the same as `q`.
-
- Raises:
- ValueError: If `form` is not handled by this function.
- """
- form = ELBOForms.default if form is None else form
-
- with ops.name_scope(name, values=[n, z]):
- z = _get_samples(q, z, n, seed)
-
- entropy = entropy_shannon(q, z=z, form=form)
-
- # If log_p(z) = Log[p(z)], cross entropy = -E_q[log(p(Z))]
- negative_cross_entropy = _sample_mean(log_p(z))
-
- return entropy + negative_cross_entropy
-
-
-def entropy_shannon(p,
- z=None,
- n=None,
- seed=None,
- form=None,
- name='entropy_shannon'):
- r"""Monte Carlo or deterministic computation of Shannon's entropy.
-
- Depending on the kwarg `form`, this `Op` returns either the analytic entropy
- of the distribution `p`, or the sampled entropy:
-
- ```
- -n^{-1} sum_{i=1}^n p.log_prob(z_i), where z_i ~ p,
- \approx - E_p[ Log[p(Z)] ]
- = Entropy[p]
- ```
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- p: `tf.contrib.distributions.Distribution`
- z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
- or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default`
- (attempt analytic entropy, fallback on sample).
- Default value is `ELBOForms.default`.
- name: A name to give this `Op`.
-
- Returns:
- A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`.
-
- Raises:
- ValueError: If `form` not handled by this function.
- ValueError: If `form` is `ELBOForms.analytic_entropy` and `n` was provided.
- """
- form = ELBOForms.default if form is None else form
-
- if n is not None and form == ELBOForms.analytic_entropy:
- raise ValueError('If form == ELBOForms.analytic_entropy, n must be None.')
-
- with ops.name_scope(name, values=[n, z]):
- # Entropy: -E_p[log(p(Z))].
- entropy = None
-
- # Try analytic path
- if form in [ELBOForms.default, ELBOForms.analytic_entropy]:
- try:
- entropy = p.entropy()
- logging.info('Using analytic entropy(p:%s)', p)
- except NotImplementedError as e:
- if form == ELBOForms.analytic_entropy:
- raise e
- elif form != ELBOForms.sample:
- raise ValueError('ELBOForm not handled by this function: %s' % form)
-
- # Sample path
- if entropy is None:
- logging.info('Using sampled entropy(p:%s)', p)
- entropy = -1. * monte_carlo.expectation(
- p.log_prob, p, z=z, n=n, seed=seed)
-
- return entropy
-
-
-def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'):
- r"""Monte Carlo estimate of the ratio appearing in Renyi divergence.
-
- This can be used to compute the Renyi (alpha) divergence, or a log evidence
- approximation based on Renyi divergence.
-
- #### Definition
-
- With `z_i` iid samples from `q`, and `exp{log_p(z)} = p(z)`, this `Op` returns
- the (biased for finite `n`) estimate:
-
- ```
- (1 - alpha)^{-1} Log[ n^{-1} sum_{i=1}^n ( p(z_i) / q(z_i) )^{1 - alpha},
- \approx (1 - alpha)^{-1} Log[ E_q[ (p(Z) / q(Z))^{1 - alpha} ] ]
- ```
-
- This ratio appears in different contexts:
-
- #### Renyi divergence
-
- If `log_p(z) = Log[p(z)]` is the log prob of a distribution, and
- `alpha > 0`, `alpha != 1`, this `Op` approximates `-1` times Renyi divergence:
-
- ```
- # Choose reasonably high n to limit bias, see below.
- renyi_ratio(log_p, q, alpha, n=100)
- \approx -1 * D_alpha[q || p], where
- D_alpha[q || p] := (1 - alpha)^{-1} Log E_q[(p(Z) / q(Z))^{1 - alpha}]
- ```
-
- The Renyi (or "alpha") divergence is non-negative and equal to zero iff
- `q = p`. Various limits of `alpha` lead to different special case results:
-
- ```
- alpha D_alpha[q || p]
- ----- ---------------
- --> 0 Log[ int_{q > 0} p(z) dz ]
- = 0.5, -2 Log[1 - Hel^2[q || p]], (\propto squared Hellinger distance)
- --> 1 KL[q || p]
- = 2 Log[ 1 + chi^2[q || p] ], (\propto squared Chi-2 divergence)
- --> infty Log[ max_z{q(z) / p(z)} ], (min description length principle).
- ```
-
- See "Renyi Divergence Variational Inference", by Li and Turner.
-
- #### Log evidence approximation
-
- If `log_p(z) = Log[p(z, x)]` is the log of the joint distribution `p`, this is
- an alternative to the ELBO common in variational inference.
-
- ```
- L_alpha(q, p) = Log[p(x)] - D_alpha[q || p]
- ```
-
- If `q` and `p` have the same support, and `0 < a <= b < 1`, one can show
- `ELBO <= D_b <= D_a <= Log[p(x)]`. Thus, this `Op` allows a smooth
- interpolation between the ELBO and the true evidence.
-
- #### Stability notes
-
- Note that when `1 - alpha` is not small, the ratio `(p(z) / q(z))^{1 - alpha}`
- is subject to underflow/overflow issues. For that reason, it is evaluated in
- log-space after centering. Nonetheless, infinite/NaN results may occur. For
- that reason, one may wish to shrink `alpha` gradually. See the `Op`
- `renyi_alpha`. Using `float64` will also help.
-
-
- #### Bias for finite sample size
-
- Due to nonlinearity of the logarithm, for random variables `{X_1,...,X_n}`,
- `E[ Log[sum_{i=1}^n X_i] ] != Log[ E[sum_{i=1}^n X_i] ]`. As a result, this
- estimate is biased for finite `n`. For `alpha < 1`, it is non-decreasing
- with `n` (in expectation). For example, if `n = 1`, this estimator yields the
- same result as `elbo_ratio`, and as `n` increases the expected value
- of the estimator increases.
-
- #### Call signature
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- log_p: Callable mapping samples from `q` to `Tensors` with
- shape broadcastable to `q.batch_shape`.
- For example, `log_p` works "just like" `q.log_prob`.
- q: `tf.contrib.distributions.Distribution`.
- `float64` `dtype` recommended.
- `log_p` and `q` should be supported on the same set.
- alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1.
- z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
- n: Integer `Tensor`. The number of samples to use if `z` is not provided.
- Note that this can be highly biased for small `n`, see docstring.
- seed: Python integer to seed the random number generator.
- name: A name to give this `Op`.
-
- Returns:
- renyi_result: The scaled log of sample mean. `Tensor` with `shape` equal
- to batch shape of `q`, and `dtype` = `q.dtype`.
- """
- with ops.name_scope(name, values=[alpha, n, z]):
- z = _get_samples(q, z, n, seed)
-
- # Evaluate sample mean in logspace. Note that _logspace_mean will compute
- # (among other things) the mean of q.log_prob(z), which could also be
- # obtained with q.entropy(). However, DON'T use analytic entropy, because
- # that increases variance, and could result in NaN/Inf values of a sensitive
- # term.
-
- # log_values
- # = (1 - alpha) * ( Log p - Log q )
- log_values = (1. - alpha) * (log_p(z) - q.log_prob(z))
-
- # log_mean_values
- # = Log[ E[ values ] ]
- # = Log[ E[ (p / q)^{1-alpha} ] ]
- log_mean_values = _logspace_mean(log_values)
-
- return log_mean_values / (1. - alpha)
-
-
-def renyi_alpha(step,
- decay_time,
- alpha_min,
- alpha_max=0.99999,
- name='renyi_alpha'):
- r"""Exponentially decaying `Tensor` appropriate for Renyi ratios.
-
- When minimizing the Renyi divergence for `0 <= alpha < 1` (or maximizing the
- Renyi equivalent of elbo) in high dimensions, it is not uncommon to experience
- `NaN` and `inf` values when `alpha` is far from `1`.
-
- For that reason, it is often desirable to start the optimization with `alpha`
- very close to 1, and reduce it to a final `alpha_min` according to some
- schedule. The user may even want to optimize using `elbo_ratio` for
- some fixed time before switching to Renyi based methods.
-
- This `Op` returns an `alpha` decaying exponentially with step:
-
- ```
- s(step) = (exp{step / decay_time} - 1) / (e - 1)
- t(s) = max(0, min(s, 1)), (smooth growth from 0 to 1)
- alpha(t) = (1 - t) alpha_min + t alpha_max
- ```
-
- Args:
- step: Non-negative scalar `Tensor`. Typically the global step or an
- offset version thereof.
- decay_time: Positive scalar `Tensor`.
- alpha_min: `float` or `double` `Tensor`.
- The minimal, final value of `alpha`, achieved when `step >= decay_time`
- alpha_max: `Tensor` of same `dtype` as `alpha_min`.
- The maximal, beginning value of `alpha`, achieved when `step == 0`
- name: A name to give this `Op`.
-
- Returns:
- alpha: A `Tensor` of same `dtype` as `alpha_min`.
- """
- with ops.name_scope(name, values=[step, decay_time, alpha_min, alpha_max]):
- alpha_min = ops.convert_to_tensor(alpha_min, name='alpha_min')
- dtype = alpha_min.dtype
-
- alpha_max = ops.convert_to_tensor(alpha_max, dtype=dtype, name='alpha_max')
- decay_time = math_ops.cast(decay_time, dtype)
- step = math_ops.cast(step, dtype)
-
- check_scalars = [
- check_ops.assert_rank(step, 0, message='step must be scalar'),
- check_ops.assert_rank(
- decay_time, 0, message='decay_time must be scalar'),
- check_ops.assert_rank(alpha_min, 0, message='alpha_min must be scalar'),
- check_ops.assert_rank(alpha_max, 0, message='alpha_max must be scalar'),
- ]
- check_sign = [
- check_ops.assert_non_negative(
- step, message='step must be non-negative'),
- check_ops.assert_positive(
- decay_time, message='decay_time must be positive'),
- ]
-
- with ops.control_dependencies(check_scalars + check_sign):
- theta = (math_ops.exp(step / decay_time) - 1.) / (math.e - 1.)
- theta = math_ops.minimum(math_ops.maximum(theta, 0.), 1.)
- return alpha_max * (1. - theta) + alpha_min * theta
+remove_undocumented(__name__, _allowed_symbols)
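
The docstrings preserved in `entropy_impl.py` below explain that `elbo_ratio(p.log_prob, q, ...)` approximates `-1 * KL[q || p]`, which can be sanity-checked against the exact divergence when it is registered. A hedged sketch under the TF 1.x contrib/session API, with distribution parameters made up for illustration:

```python
import tensorflow as tf

distributions = tf.contrib.distributions
entropy = tf.contrib.bayesflow.entropy

q = distributions.Normal(loc=0.0, scale=1.0)
p = distributions.Normal(loc=0.5, scale=2.0)

# Monte Carlo estimate of -KL[q || p]; the analytic entropy of q is used by default.
neg_kl_estimate = entropy.elbo_ratio(p.log_prob, q, n=10000, seed=42)
exact_kl = distributions.kl(q, p)

with tf.Session() as sess:
    estimate, exact = sess.run([-neg_kl_estimate, exact_kl])
    print(estimate, exact)  # the two values should roughly agree
```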
diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py
new file mode 100644
index 0000000000..ef9fb73025
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py
@@ -0,0 +1,384 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}.
+
+@@elbo_ratio
+@@entropy_shannon
+@@renyi_ratio
+@@renyi_alpha
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo
+from tensorflow.contrib.bayesflow.python.ops import variational_inference
+from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples as get_samples
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
+
+
+# Make utility functions from monte_carlo available.
+# pylint: disable=protected-access
+_get_samples = get_samples
+_logspace_mean = monte_carlo._logspace_mean
+_sample_mean = monte_carlo._sample_mean
+
+# pylint: enable=protected-access
+
+__all__ = [
+ 'elbo_ratio',
+ 'entropy_shannon',
+ 'renyi_ratio',
+ 'renyi_alpha',
+]
+
+ELBOForms = variational_inference.ELBOForms # pylint: disable=invalid-name
+
+
+def elbo_ratio(log_p,
+ q,
+ z=None,
+ n=None,
+ seed=None,
+ form=None,
+ name='elbo_ratio'):
+ r"""Estimate of the ratio appearing in the `ELBO` and `KL` divergence.
+
+ With `p(z) := exp{log_p(z)}`, this `Op` returns an approximation of
+
+ ```
+ E_q[ Log[p(Z) / q(Z)] ]
+ ```
+
+ The term `E_q[ Log[p(Z)] ]` is always computed as a sample mean.
+ The term `E_q[ Log[q(z)] ]` can be computed with samples, or an exact formula
+ if `q.entropy()` is defined. This is controlled with the kwarg `form`.
+
+ This log-ratio appears in different contexts:
+
+ #### `KL[q || p]`
+
+ If `log_p(z) = Log[p(z)]` for distribution `p`, this `Op` approximates
+ the negative Kullback-Leibler divergence.
+
+ ```
+ elbo_ratio(log_p, q, n=100) = -1 * KL[q || p],
+ KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ]
+ ```
+
+ Note that if `p` is a `Distribution`, then `distributions.kl(q, p)` may be
+ defined and available as an exact result.
+
+ #### ELBO
+
+ If `log_p(z) = Log[p(z, x)]` is the log joint of a distribution `p`, this is
+ the Evidence Lower BOund (ELBO):
+
+ ```
+ ELBO ~= E[ Log[p(Z, x)] - Log[q(Z)] ]
+ = Log[p(x)] - KL[q || p]
+ <= Log[p(x)]
+ ```
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ log_p: Callable mapping samples from `q` to `Tensors` with
+ shape broadcastable to `q.batch_shape`.
+ For example, `log_p` works "just like" `q.log_prob`.
+ q: `tf.contrib.distributions.Distribution`.
+ z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`.
+ n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
+ seed: Python integer to seed the random number generator.
+ form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
+ or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default`
+ (attempt analytic entropy, fallback on sample).
+ Default value is `ELBOForms.default`.
+ name: A name to give this `Op`.
+
+ Returns:
+ Scalar `Tensor` holding sample mean KL divergence. `shape` is the batch
+ shape of `q`, and `dtype` is the same as `q`.
+
+ Raises:
+ ValueError: If `form` is not handled by this function.
+ """
+ form = ELBOForms.default if form is None else form
+
+ with ops.name_scope(name, values=[n, z]):
+ z = _get_samples(q, z, n, seed)
+
+ entropy = entropy_shannon(q, z=z, form=form)
+
+ # If log_p(z) = Log[p(z)], cross entropy = -E_q[log(p(Z))]
+ negative_cross_entropy = _sample_mean(log_p(z))
+
+ return entropy + negative_cross_entropy
+
+
+def entropy_shannon(p,
+ z=None,
+ n=None,
+ seed=None,
+ form=None,
+ name='entropy_shannon'):
+ r"""Monte Carlo or deterministic computation of Shannon's entropy.
+
+ Depending on the kwarg `form`, this `Op` returns either the analytic entropy
+ of the distribution `p`, or the sampled entropy:
+
+ ```
+ -n^{-1} sum_{i=1}^n p.log_prob(z_i), where z_i ~ p,
+ \approx - E_p[ Log[p(Z)] ]
+ = Entropy[p]
+ ```
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ p: `tf.contrib.distributions.Distribution`
+ z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`.
+ n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
+ seed: Python integer to seed the random number generator.
+ form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`)
+ or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default`
+ (attempt analytic entropy, fallback on sample).
+ Default value is `ELBOForms.default`.
+ name: A name to give this `Op`.
+
+ Returns:
+ A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`.
+
+ Raises:
+ ValueError: If `form` not handled by this function.
+ ValueError: If `form` is `ELBOForms.analytic_entropy` and `n` was provided.
+ """
+ form = ELBOForms.default if form is None else form
+
+ if n is not None and form == ELBOForms.analytic_entropy:
+ raise ValueError('If form == ELBOForms.analytic_entropy, n must be None.')
+
+ with ops.name_scope(name, values=[n, z]):
+ # Entropy: -E_p[log(p(Z))].
+ entropy = None
+
+ # Try analytic path
+ if form in [ELBOForms.default, ELBOForms.analytic_entropy]:
+ try:
+ entropy = p.entropy()
+ logging.info('Using analytic entropy(p:%s)', p)
+ except NotImplementedError as e:
+ if form == ELBOForms.analytic_entropy:
+ raise e
+ elif form != ELBOForms.sample:
+ raise ValueError('ELBOForm not handled by this function: %s' % form)
+
+ # Sample path
+ if entropy is None:
+ logging.info('Using sampled entropy(p:%s)', p)
+ entropy = -1. * monte_carlo.expectation(
+ p.log_prob, p, z=z, n=n, seed=seed)
+
+ return entropy
+
+
+def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'):
+ r"""Monte Carlo estimate of the ratio appearing in Renyi divergence.
+
+ This can be used to compute the Renyi (alpha) divergence, or a log evidence
+ approximation based on Renyi divergence.
+
+ #### Definition
+
+ With `z_i` iid samples from `q`, and `exp{log_p(z)} = p(z)`, this `Op` returns
+ the (biased for finite `n`) estimate:
+
+ ```
+ (1 - alpha)^{-1} Log[ n^{-1} sum_{i=1}^n ( p(z_i) / q(z_i) )^{1 - alpha},
+ \approx (1 - alpha)^{-1} Log[ E_q[ (p(Z) / q(Z))^{1 - alpha} ] ]
+ ```
+
+ This ratio appears in different contexts:
+
+ #### Renyi divergence
+
+ If `log_p(z) = Log[p(z)]` is the log prob of a distribution, and
+ `alpha > 0`, `alpha != 1`, this `Op` approximates `-1` times Renyi divergence:
+
+ ```
+ # Choose reasonably high n to limit bias, see below.
+ renyi_ratio(log_p, q, alpha, n=100)
+ \approx -1 * D_alpha[q || p], where
+ D_alpha[q || p] := (1 - alpha)^{-1} Log E_q[(p(Z) / q(Z))^{1 - alpha}]
+ ```
+
+ The Renyi (or "alpha") divergence is non-negative and equal to zero iff
+ `q = p`. Various limits of `alpha` lead to different special case results:
+
+ ```
+ alpha D_alpha[q || p]
+ ----- ---------------
+ --> 0 Log[ int_{q > 0} p(z) dz ]
+ = 0.5, -2 Log[1 - Hel^2[q || p]], (\propto squared Hellinger distance)
+ --> 1 KL[q || p]
+ = 2 Log[ 1 + chi^2[q || p] ], (\propto squared Chi-2 divergence)
+ --> infty Log[ max_z{q(z) / p(z)} ], (min description length principle).
+ ```
+
+ See "Renyi Divergence Variational Inference", by Li and Turner.
+
+ #### Log evidence approximation
+
+ If `log_p(z) = Log[p(z, x)]` is the log of the joint distribution `p`, this is
+ an alternative to the ELBO common in variational inference.
+
+ ```
+ L_alpha(q, p) = Log[p(x)] - D_alpha[q || p]
+ ```
+
+ If `q` and `p` have the same support, and `0 < a <= b < 1`, one can show
+ `ELBO <= D_b <= D_a <= Log[p(x)]`. Thus, this `Op` allows a smooth
+ interpolation between the ELBO and the true evidence.
+
+ #### Stability notes
+
+ Note that when `1 - alpha` is not small, the ratio `(p(z) / q(z))^{1 - alpha}`
+ is subject to underflow/overflow issues. For that reason, it is evaluated in
+ log-space after centering. Nonetheless, infinite/NaN results may occur. For
+ that reason, one may wish to shrink `alpha` gradually. See the `Op`
+ `renyi_alpha`. Using `float64` will also help.
+
+
+ #### Bias for finite sample size
+
+ Due to nonlinearity of the logarithm, for random variables `{X_1,...,X_n}`,
+ `E[ Log[sum_{i=1}^n X_i] ] != Log[ E[sum_{i=1}^n X_i] ]`. As a result, this
+ estimate is biased for finite `n`. For `alpha < 1`, it is non-decreasing
+ with `n` (in expectation). For example, if `n = 1`, this estimator yields the
+ same result as `elbo_ratio`, and as `n` increases the expected value
+ of the estimator increases.
+
+ #### Call signature
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ log_p: Callable mapping samples from `q` to `Tensors` with
+ shape broadcastable to `q.batch_shape`.
+ For example, `log_p` works "just like" `q.log_prob`.
+ q: `tf.contrib.distributions.Distribution`.
+ `float64` `dtype` recommended.
+ `log_p` and `q` should be supported on the same set.
+ alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
+ n: Integer `Tensor`. The number of samples to use if `z` is not provided.
+ Note that this can be highly biased for small `n`, see docstring.
+ seed: Python integer to seed the random number generator.
+ name: A name to give this `Op`.
+
+ Returns:
+ renyi_result: The scaled log of sample mean. `Tensor` with `shape` equal
+ to batch shape of `q`, and `dtype` = `q.dtype`.
+ """
+ with ops.name_scope(name, values=[alpha, n, z]):
+ z = _get_samples(q, z, n, seed)
+
+ # Evaluate sample mean in logspace. Note that _logspace_mean will compute
+ # (among other things) the mean of q.log_prob(z), which could also be
+ # obtained with q.entropy(). However, DON'T use analytic entropy, because
+ # that increases variance, and could result in NaN/Inf values of a sensitive
+ # term.
+
+ # log_values
+ # = (1 - alpha) * ( Log p - Log q )
+ log_values = (1. - alpha) * (log_p(z) - q.log_prob(z))
+
+ # log_mean_values
+ # = Log[ E[ values ] ]
+ # = Log[ E[ (p / q)^{1-alpha} ] ]
+ log_mean_values = _logspace_mean(log_values)
+
+ return log_mean_values / (1. - alpha)
+
+
+def renyi_alpha(step,
+ decay_time,
+ alpha_min,
+ alpha_max=0.99999,
+ name='renyi_alpha'):
+ r"""Exponentially decaying `Tensor` appropriate for Renyi ratios.
+
+ When minimizing the Renyi divergence for `0 <= alpha < 1` (or maximizing the
+ Renyi equivalent of elbo) in high dimensions, it is not uncommon to experience
+ `NaN` and `inf` values when `alpha` is far from `1`.
+
+ For that reason, it is often desirable to start the optimization with `alpha`
+ very close to 1, and reduce it to a final `alpha_min` according to some
+ schedule. The user may even want to optimize using `elbo_ratio` for
+ some fixed time before switching to Renyi based methods.
+
+ This `Op` returns an `alpha` decaying exponentially with step:
+
+ ```
+ s(step) = (exp{step / decay_time} - 1) / (e - 1)
+ t(s) = max(0, min(s, 1)), (smooth growth from 0 to 1)
+ alpha(t) = (1 - t) alpha_min + t alpha_max
+ ```
+
+ Args:
+ step: Non-negative scalar `Tensor`. Typically the global step or an
+ offset version thereof.
+ decay_time: Positive scalar `Tensor`.
+ alpha_min: `float` or `double` `Tensor`.
+ The minimal, final value of `alpha`, achieved when `step >= decay_time`
+ alpha_max: `Tensor` of same `dtype` as `alpha_min`.
+ The maximal, beginning value of `alpha`, achieved when `step == 0`
+ name: A name to give this `Op`.
+
+ Returns:
+ alpha: A `Tensor` of same `dtype` as `alpha_min`.
+ """
+ with ops.name_scope(name, values=[step, decay_time, alpha_min, alpha_max]):
+ alpha_min = ops.convert_to_tensor(alpha_min, name='alpha_min')
+ dtype = alpha_min.dtype
+
+ alpha_max = ops.convert_to_tensor(alpha_max, dtype=dtype, name='alpha_max')
+ decay_time = math_ops.cast(decay_time, dtype)
+ step = math_ops.cast(step, dtype)
+
+ check_scalars = [
+ check_ops.assert_rank(step, 0, message='step must be scalar'),
+ check_ops.assert_rank(
+ decay_time, 0, message='decay_time must be scalar'),
+ check_ops.assert_rank(alpha_min, 0, message='alpha_min must be scalar'),
+ check_ops.assert_rank(alpha_max, 0, message='alpha_max must be scalar'),
+ ]
+ check_sign = [
+ check_ops.assert_non_negative(
+ step, message='step must be non-negative'),
+ check_ops.assert_positive(
+ decay_time, message='decay_time must be positive'),
+ ]
+
+ with ops.control_dependencies(check_scalars + check_sign):
+ theta = (math_ops.exp(step / decay_time) - 1.) / (math.e - 1.)
+ theta = math_ops.minimum(math_ops.maximum(theta, 0.), 1.)
+ return alpha_max * (1. - theta) + alpha_min * theta
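
The schedule computed at the end of `renyi_alpha` above interpolates from `alpha_max` at `step == 0` down to `alpha_min` once `step >= decay_time`. A NumPy restatement of that body, with made-up step values, may make the clipping and interpolation easier to see:

```python
import numpy as np

def renyi_alpha_schedule(step, decay_time, alpha_min, alpha_max=0.99999):
    """NumPy sketch of the interpolation in renyi_alpha above."""
    theta = (np.exp(step / float(decay_time)) - 1.0) / (np.e - 1.0)
    theta = np.clip(theta, 0.0, 1.0)
    # alpha_max at step == 0; alpha_min once step >= decay_time.
    return alpha_max * (1.0 - theta) + alpha_min * theta

for step in [0, 250, 500, 1000, 2000]:
    print(step, renyi_alpha_schedule(step, decay_time=1000, alpha_min=0.5))
```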
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
index 55e0e6d57b..5770bcdd70 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
@@ -15,260 +15,22 @@
"""Monte Carlo integration and helpers.
See the @{$python/contrib.bayesflow.monte_carlo} guide.
-
-@@expectation
-@@expectation_importance_sampler
-@@expectation_importance_sampler_logspace
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
-__all__ = [
+_allowed_symbols = [
'expectation',
'expectation_importance_sampler',
'expectation_importance_sampler_logspace',
]
-
-def expectation_importance_sampler(f,
- log_p,
- sampling_dist_q,
- z=None,
- n=None,
- seed=None,
- name='expectation_importance_sampler'):
- r"""Monte Carlo estimate of `E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]`.
-
- With `p(z) := exp{log_p(z)}`, this `Op` returns
-
- ```
- n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,
- \approx E_q[ f(Z) p(Z) / q(Z) ]
- = E_p[f(Z)]
- ```
-
- This integral is done in log-space with max-subtraction to better handle the
- often extreme values that `f(z) p(z) / q(z)` can take on.
-
- If `f >= 0`, it is up to 2x more efficient to exponentiate the result of
- `expectation_importance_sampler_logspace` applied to `Log[f]`.
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- f: Callable mapping samples from `sampling_dist_q` to `Tensors` with shape
- broadcastable to `q.batch_shape`.
- For example, `f` works "just like" `q.log_prob`.
- log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
- shape broadcastable to `q.batch_shape`.
- For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
- sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
- `float64` `dtype` recommended.
- `log_p` and `q` should be supported on the same set.
- z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- name: A name to give this `Op`.
-
- Returns:
- The importance sampling estimate. `Tensor` with `shape` equal
- to batch shape of `q`, and `dtype` = `q.dtype`.
- """
- q = sampling_dist_q
- with ops.name_scope(name, values=[z, n]):
- z = _get_samples(q, z, n, seed)
-
- log_p_z = log_p(z)
- q_log_prob_z = q.log_prob(z)
-
- def _importance_sampler_positive_f(log_f_z):
- # Same as expectation_importance_sampler_logspace, but using Tensors
- # rather than samples and functions. Allows us to sample once.
- log_values = log_f_z + log_p_z - q_log_prob_z
- return _logspace_mean(log_values)
-
- # With f_plus(z) = max(0, f(z)), f_minus(z) = max(0, -f(z)),
- # E_p[f(Z)] = E_p[f_plus(Z)] - E_p[f_minus(Z)]
- # = E_p[f_plus(Z) + 1] - E_p[f_minus(Z) + 1]
- # Without incurring bias, 1 is added to each to prevent zeros in logspace.
- # The logarithm is approximately linear around 1 + epsilon, so this is good
- # for small values of 'z' as well.
- f_z = f(z)
- log_f_plus_z = math_ops.log(nn.relu(f_z) + 1.)
- log_f_minus_z = math_ops.log(nn.relu(-1. * f_z) + 1.)
-
- log_f_plus_integral = _importance_sampler_positive_f(log_f_plus_z)
- log_f_minus_integral = _importance_sampler_positive_f(log_f_minus_z)
-
- return math_ops.exp(log_f_plus_integral) - math_ops.exp(log_f_minus_integral)
-
-
-def expectation_importance_sampler_logspace(
- log_f,
- log_p,
- sampling_dist_q,
- z=None,
- n=None,
- seed=None,
- name='expectation_importance_sampler_logspace'):
- r"""Importance sampling with a positive function, in log-space.
-
- With `p(z) := exp{log_p(z)}`, and `f(z) = exp{log_f(z)}`, this `Op`
- returns
-
- ```
- Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,
- \approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]
- = Log[E_p[f(Z)]]
- ```
-
- This integral is done in log-space with max-subtraction to better handle the
- often extreme values that `f(z) p(z) / q(z)` can take on.
-
- In contrast to `expectation_importance_sampler`, this `Op` returns values in
- log-space.
-
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- log_f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
- shape broadcastable to `q.batch_shape`.
- For example, `log_f` works "just like" `sampling_dist_q.log_prob`.
- log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
- shape broadcastable to `q.batch_shape`.
- For example, `log_p` works "just like" `q.log_prob`.
- sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
- `float64` `dtype` recommended.
- `log_p` and `q` should be supported on the same set.
- z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- name: A name to give this `Op`.
-
- Returns:
- Logarithm of the importance sampling estimate. `Tensor` with `shape` equal
- to batch shape of `q`, and `dtype` = `q.dtype`.
- """
- q = sampling_dist_q
- with ops.name_scope(name, values=[z, n]):
- z = _get_samples(q, z, n, seed)
- log_values = log_f(z) + log_p(z) - q.log_prob(z)
- return _logspace_mean(log_values)
-
-
-def _logspace_mean(log_values):
- """Evaluate `Log[E[values]]` in a stable manner.
-
- Args:
- log_values: `Tensor` holding `Log[values]`.
-
- Returns:
- `Tensor` of same `dtype` as `log_values`, reduced across dim 0.
- `Log[Mean[values]]`.
- """
- # center = Max[Log[values]], with stop-gradient
- # The center hopefully keep the exponentiated term small. It is cancelled
- # from the final result, so putting stop gradient on it will not change the
- # final result. We put stop gradient on to eliminate unnecessary computation.
- center = array_ops.stop_gradient(_sample_max(log_values))
-
- # centered_values = exp{Log[values] - E[Log[values]]}
- centered_values = math_ops.exp(log_values - center)
-
- # log_mean_of_values = Log[ E[centered_values] ] + center
- # = Log[ E[exp{log_values - E[log_values]}] ] + center
- # = Log[E[values]] - E[log_values] + center
- # = Log[E[values]]
- log_mean_of_values = math_ops.log(_sample_mean(centered_values)) + center
-
- return log_mean_of_values
-
-
-def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
- r"""Monte Carlo estimate of an expectation: `E_p[f(Z)]` with sample mean.
-
- This `Op` returns
-
- ```
- n^{-1} sum_{i=1}^n f(z_i), where z_i ~ p
- \approx E_p[f(Z)]
- ```
-
- User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
-
- Args:
- f: Callable mapping samples from `p` to `Tensors`.
- p: `tf.contrib.distributions.Distribution`.
- z: `Tensor` of samples from `p`, produced by `p.sample` for some `n`.
- n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
- seed: Python integer to seed the random number generator.
- name: A name to give this `Op`.
-
- Returns:
- A `Tensor` with the same `dtype` as `p`.
-
- Example:
-
- ```python
- N_samples = 10000
-
- distributions = tf.contrib.distributions
-
- dist = distributions.Uniform([0.0, 0.0], [1.0, 2.0])
- elementwise_mean = lambda x: x
- mean_sum = lambda x: tf.reduce_sum(x, 1)
-
- estimate_elementwise_mean_tf = monte_carlo.expectation(elementwise_mean,
- dist,
- n=N_samples)
- estimate_mean_sum_tf = monte_carlo.expectation(mean_sum,
- dist,
- n=N_samples)
-
- with tf.Session() as sess:
- estimate_elementwise_mean, estimate_mean_sum = (
- sess.run([estimate_elementwise_mean_tf, estimate_mean_sum_tf]))
- print estimate_elementwise_mean
- >>> np.array([ 0.50018013 1.00097895], dtype=np.float32)
- print estimate_mean_sum
- >>> 1.49571
-
- ```
-
- """
- with ops.name_scope(name, values=[n, z]):
- z = _get_samples(p, z, n, seed)
- return _sample_mean(f(z))
-
-
-def _sample_mean(values):
- """Mean over sample indices. In this module this is always [0]."""
- return math_ops.reduce_mean(values, reduction_indices=[0])
-
-
-def _sample_max(values):
- """Max over sample indices. In this module this is always [0]."""
- return math_ops.reduce_max(values, reduction_indices=[0])
-
-
-def _get_samples(dist, z, n, seed):
- """Check args and return samples."""
- with ops.name_scope('get_samples', values=[z, n]):
- if (n is None) == (z is None):
- raise ValueError(
- 'Must specify exactly one of arguments "n" and "z". Found: '
- 'n = %s, z = %s' % (n, z))
- if n is not None:
- return dist.sample(n, seed=seed)
- else:
- return ops.convert_to_tensor(z, name='z')
+remove_undocumented(__name__, _allowed_symbols)
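
The estimators that used to live in this file (and reappear verbatim in `monte_carlo_impl.py` below) evaluate `E_p[f(Z)]` from samples drawn under a different distribution `q`, doing the arithmetic in log-space with max-subtraction. A hedged usage sketch under the TF 1.x contrib/session API, with made-up distributions chosen so the answer is known (`E_p[Z^2] = 2` for `p = Normal(1, 1)`):

```python
import tensorflow as tf

distributions = tf.contrib.distributions
monte_carlo = tf.contrib.bayesflow.monte_carlo

p = distributions.Normal(loc=1.0, scale=1.0)   # target distribution
q = distributions.Normal(loc=0.0, scale=3.0)   # sampling distribution with wider support

estimate = monte_carlo.expectation_importance_sampler(
    f=tf.square,
    log_p=p.log_prob,
    sampling_dist_q=q,
    n=20000,
    seed=0)

with tf.Session() as sess:
    print(sess.run(estimate))  # should land near 2.0
```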
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
new file mode 100644
index 0000000000..a8654bcf31
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -0,0 +1,274 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Monte Carlo integration and helpers.
+
+See the ${@python/contrib.bayesflow.monte_carlo} guide.
+
+@@expectation
+@@expectation_importance_sampler
+@@expectation_importance_sampler_logspace
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+__all__ = [
+ 'expectation',
+ 'expectation_importance_sampler',
+ 'expectation_importance_sampler_logspace',
+]
+
+
+def expectation_importance_sampler(f,
+ log_p,
+ sampling_dist_q,
+ z=None,
+ n=None,
+ seed=None,
+ name='expectation_importance_sampler'):
+ r"""Monte Carlo estimate of `E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]`.
+
+ With `p(z) := exp{log_p(z)}`, this `Op` returns
+
+ ```
+ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,
+ \approx E_q[ f(Z) p(Z) / q(Z) ]
+ = E_p[f(Z)]
+ ```
+
+ This integral is done in log-space with max-subtraction to better handle the
+ often extreme values that `f(z) p(z) / q(z)` can take on.
+
+ If `f >= 0`, it is up to 2x more efficient to exponentiate the result of
+ `expectation_importance_sampler_logspace` applied to `Log[f]`.
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ f: Callable mapping samples from `sampling_dist_q` to `Tensors` with shape
+ broadcastable to `q.batch_shape`.
+ For example, `f` works "just like" `q.log_prob`.
+ log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
+ shape broadcastable to `q.batch_shape`.
+ For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
+ sampling_dist_q: The sampling distribution.
+ `tf.contrib.distributions.Distribution`.
+ `float64` `dtype` recommended.
+ `log_p` and `q` should be supported on the same set.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
+ n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
+ seed: Python integer to seed the random number generator.
+ name: A name to give this `Op`.
+
+ Returns:
+ The importance sampling estimate. `Tensor` with `shape` equal
+ to batch shape of `q`, and `dtype` = `q.dtype`.
+ """
+ q = sampling_dist_q
+ with ops.name_scope(name, values=[z, n]):
+ z = _get_samples(q, z, n, seed)
+
+ log_p_z = log_p(z)
+ q_log_prob_z = q.log_prob(z)
+
+ def _importance_sampler_positive_f(log_f_z):
+ # Same as expectation_importance_sampler_logspace, but using Tensors
+ # rather than samples and functions. Allows us to sample once.
+ log_values = log_f_z + log_p_z - q_log_prob_z
+ return _logspace_mean(log_values)
+
+ # With f_plus(z) = max(0, f(z)), f_minus(z) = max(0, -f(z)),
+ # E_p[f(Z)] = E_p[f_plus(Z)] - E_p[f_minus(Z)]
+ # = E_p[f_plus(Z) + 1] - E_p[f_minus(Z) + 1]
+ # Without incurring bias, 1 is added to each to prevent zeros in logspace.
+ # The logarithm is approximately linear around 1 + epsilon, so this is good
+ # for small values of 'z' as well.
+ f_z = f(z)
+ log_f_plus_z = math_ops.log(nn.relu(f_z) + 1.)
+ log_f_minus_z = math_ops.log(nn.relu(-1. * f_z) + 1.)
+
+ log_f_plus_integral = _importance_sampler_positive_f(log_f_plus_z)
+ log_f_minus_integral = _importance_sampler_positive_f(log_f_minus_z)
+
+ return math_ops.exp(log_f_plus_integral) - math_ops.exp(log_f_minus_integral)
+
+
+def expectation_importance_sampler_logspace(
+ log_f,
+ log_p,
+ sampling_dist_q,
+ z=None,
+ n=None,
+ seed=None,
+ name='expectation_importance_sampler_logspace'):
+ r"""Importance sampling with a positive function, in log-space.
+
+ With `p(z) := exp{log_p(z)}`, and `f(z) = exp{log_f(z)}`, this `Op`
+ returns
+
+ ```
+ Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,
+ \approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]
+ = Log[E_p[f(Z)]]
+ ```
+
+ This integral is done in log-space with max-subtraction to better handle the
+ often extreme values that `f(z) p(z) / q(z)` can take on.
+
+ In contrast to `expectation_importance_sampler`, this `Op` returns values in
+ log-space.
+
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ log_f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
+ shape broadcastable to `q.batch_shape`.
+ For example, `log_f` works "just like" `sampling_dist_q.log_prob`.
+ log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
+ shape broadcastable to `q.batch_shape`.
+ For example, `log_p` works "just like" `q.log_prob`.
+ sampling_dist_q: The sampling distribution.
+ `tf.contrib.distributions.Distribution`.
+ `float64` `dtype` recommended.
+ `log_p` and `q` should be supported on the same set.
+ z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
+ n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
+ seed: Python integer to seed the random number generator.
+ name: A name to give this `Op`.
+
+ Returns:
+ Logarithm of the importance sampling estimate. `Tensor` with `shape` equal
+ to batch shape of `q`, and `dtype` = `q.dtype`.
+ """
+ q = sampling_dist_q
+ with ops.name_scope(name, values=[z, n]):
+ z = _get_samples(q, z, n, seed)
+ log_values = log_f(z) + log_p(z) - q.log_prob(z)
+ return _logspace_mean(log_values)
+
+
+def _logspace_mean(log_values):
+ """Evaluate `Log[E[values]]` in a stable manner.
+
+ Args:
+ log_values: `Tensor` holding `Log[values]`.
+
+ Returns:
+ `Tensor` of same `dtype` as `log_values`, reduced across dim 0.
+ `Log[Mean[values]]`.
+ """
+ # center = Max[Log[values]], with stop-gradient
+ # The center hopefully keep the exponentiated term small. It is cancelled
+ # from the final result, so putting stop gradient on it will not change the
+ # final result. We put stop gradient on to eliminate unnecessary computation.
+ center = array_ops.stop_gradient(_sample_max(log_values))
+
+ # centered_values = exp{Log[values] - E[Log[values]]}
+ centered_values = math_ops.exp(log_values - center)
+
+ # log_mean_of_values = Log[ E[centered_values] ] + center
+ # = Log[ E[exp{log_values - E[log_values]}] ] + center
+ # = Log[E[values]] - E[log_values] + center
+ # = Log[E[values]]
+ log_mean_of_values = math_ops.log(_sample_mean(centered_values)) + center
+
+ return log_mean_of_values
+
+
+def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
+ r"""Monte Carlo estimate of an expectation: `E_p[f(Z)]` with sample mean.
+
+ This `Op` returns
+
+ ```
+ n^{-1} sum_{i=1}^n f(z_i), where z_i ~ p
+ \approx E_p[f(Z)]
+ ```
+
+ User supplies either `Tensor` of samples `z`, or number of samples to draw `n`
+
+ Args:
+ f: Callable mapping samples from `p` to `Tensors`.
+ p: `tf.contrib.distributions.Distribution`.
+ z: `Tensor` of samples from `p`, produced by `p.sample` for some `n`.
+ n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
+ seed: Python integer to seed the random number generator.
+ name: A name to give this `Op`.
+
+ Returns:
+ A `Tensor` with the same `dtype` as `p`.
+
+ Example:
+
+ ```python
+ N_samples = 10000
+
+ distributions = tf.contrib.distributions
+
+ dist = distributions.Uniform([0.0, 0.0], [1.0, 2.0])
+ elementwise_mean = lambda x: x
+ mean_sum = lambda x: tf.reduce_sum(x, 1)
+
+ estimate_elementwise_mean_tf = monte_carlo.expectation(elementwise_mean,
+ dist,
+ n=N_samples)
+ estimate_mean_sum_tf = monte_carlo.expectation(mean_sum,
+ dist,
+ n=N_samples)
+
+ with tf.Session() as sess:
+ estimate_elementwise_mean, estimate_mean_sum = (
+ sess.run([estimate_elementwise_mean_tf, estimate_mean_sum_tf]))
+ print estimate_elementwise_mean
+ >>> np.array([ 0.50018013 1.00097895], dtype=np.float32)
+ print estimate_mean_sum
+ >>> 1.49571
+
+ ```
+
+ """
+ with ops.name_scope(name, values=[n, z]):
+ z = _get_samples(p, z, n, seed)
+ return _sample_mean(f(z))
+
+
+def _sample_mean(values):
+ """Mean over sample indices. In this module this is always [0]."""
+ return math_ops.reduce_mean(values, reduction_indices=[0])
+
+
+def _sample_max(values):
+ """Max over sample indices. In this module this is always [0]."""
+ return math_ops.reduce_max(values, reduction_indices=[0])
+
+
+def _get_samples(dist, z, n, seed):
+ """Check args and return samples."""
+ with ops.name_scope('get_samples', values=[z, n]):
+ if (n is None) == (z is None):
+ raise ValueError(
+ 'Must specify exactly one of arguments "n" and "z". Found: '
+ 'n = %s, z = %s' % (n, z))
+ if n is not None:
+ return dist.sample(n, seed=seed)
+ else:
+ return ops.convert_to_tensor(z, name='z')
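
`_logspace_mean` above is the workhorse behind these estimators: it computes `Log[Mean[values]]` from `Log[values]` by subtracting the per-batch maximum before exponentiating, then adding it back. A small NumPy restatement of the same identity, with made-up inputs that would underflow a naive computation:

```python
import numpy as np

def logspace_mean(log_values):
    """Log[Mean[values]] computed stably from log_values, as in _logspace_mean above."""
    center = np.max(log_values, axis=0)        # plays the role of the stop-gradient center
    centered = np.exp(log_values - center)     # shifted values are <= 1, so no overflow
    return np.log(np.mean(centered, axis=0)) + center

log_values = np.array([-1000.0, -1001.0, -1002.0])
print(logspace_mean(log_values))               # ~ -1000.69, finite
# The naive route underflows: np.log(np.mean(np.exp(log_values))) gives -inf.
```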
diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py b/tensorflow/contrib/bayesflow/python/ops/variational_inference.py
index 17a8666686..6316361da2 100644
--- a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py
+++ b/tensorflow/contrib/bayesflow/python/ops/variational_inference.py
@@ -15,313 +15,20 @@
"""Variational inference.
See the ${@python/contrib.bayesflow.variational_inference} guide.
-
-@@elbo
-@@elbo_with_log_joint
-@@ELBOForms
-@@register_prior
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.bayesflow.python.ops import stochastic_graph as sg
-from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st
-from tensorflow.contrib.distributions.python.ops import distribution
-from tensorflow.contrib.distributions.python.ops import kullback_leibler
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import tf_logging as logging
-
-VI_PRIORS = "__vi_priors__"
-
-
-def register_prior(variational, prior):
- """Associate a variational `StochasticTensor` with a `Distribution` prior.
-
- This is a helper function used in conjunction with `elbo` that allows users
- to specify the mapping between variational distributions and their priors
- without having to pass in `variational_with_prior` explicitly.
-
- Args:
- variational: `StochasticTensor` q(Z). Approximating distribution.
- prior: `Distribution` p(Z). Prior distribution.
-
- Returns:
- None
-
- Raises:
- ValueError: if variational is not a `StochasticTensor` or `prior` is not
- a `Distribution`.
- """
- if not isinstance(variational, st.StochasticTensor):
- raise TypeError("variational must be a StochasticTensor")
- if not isinstance(prior, distribution.Distribution):
- raise TypeError("prior must be a Distribution")
- ops.add_to_collection(VI_PRIORS, (variational, prior))
-
-
-class _ELBOForm(object):
- pass
-
-
-class ELBOForms(object):
- """Constants to control the `elbo` calculation.
-
- `analytic_kl` uses the analytic KL divergence between the
- variational distribution(s) and the prior(s).
-
- `analytic_entropy` uses the analytic entropy of the variational
- distribution(s).
-
- `sample` uses the sample KL or the sample entropy is the joint is provided.
-
- See `elbo` for what is used with `default`.
- """
- default, analytic_kl, analytic_entropy, sample = (_ELBOForm()
- for _ in range(4))
-
- @staticmethod
- def check_form(form):
- if form not in {
- ELBOForms.default, ELBOForms.analytic_kl, ELBOForms.analytic_entropy,
- ELBOForms.sample
- }:
- raise TypeError("form must be an ELBOForms constant")
-
-
-def elbo(log_likelihood,
- variational_with_prior=None,
- keep_batch_dim=True,
- form=None,
- name="ELBO"):
- r"""Evidence Lower BOund. `log p(x) >= ELBO`.
-
- Optimization objective for inference of hidden variables by variational
- inference.
-
- This function is meant to be used in conjunction with `StochasticTensor`.
- The user should build out the inference network, using `StochasticTensor`s
- as latent variables, and the generative network. `elbo` at minimum needs
- `p(x|Z)` and assumes that all `StochasticTensor`s upstream of `p(x|Z)` are
- the variational distributions. Use `register_prior` to register `Distribution`
- priors for each `StochasticTensor`. Alternatively, pass in
- `variational_with_prior` specifying all variational distributions and their
- priors.
-
- Mathematical details:
-
- ```
- log p(x) = log \int p(x, Z) dZ
- = log \int \frac {q(Z)p(x, Z)}{q(Z)} dZ
- = log E_q[\frac {p(x, Z)}{q(Z)}]
- >= E_q[log \frac {p(x, Z)}{q(Z)}] = L[q; p, x] # ELBO
-
- L[q; p, x] = E_q[log p(x|Z)p(Z)] - E_q[log q(Z)]
- = E_q[log p(x|Z)p(Z)] + H[q] (1)
- = E_q[log p(x|Z)] - KL(q || p) (2)
-
- H - Entropy
- KL - Kullback-Leibler divergence
- ```
-
- See section 2.2 of Stochastic Variational Inference by Hoffman et al. for
- more, including the ELBO's equivalence to minimizing `KL(q(Z)||p(Z|x))`
- in the fully Bayesian setting. https://arxiv.org/pdf/1206.7051.pdf.
-
- `form` specifies which form of the ELBO is used. `form=ELBOForms.default`
- tries, in order of preference: analytic KL, analytic entropy, sampling.
-
- Multiple entries in the `variational_with_prior` dict implies a factorization.
- e.g. `q(Z) = q(z1)q(z2)q(z3)`.
-
- Args:
- log_likelihood: `Tensor` log p(x|Z).
- variational_with_prior: dict from `StochasticTensor` q(Z) to
- `Distribution` p(Z). If `None`, defaults to all `StochasticTensor`
- objects upstream of `log_likelihood` with priors registered with
- `register_prior`.
- keep_batch_dim: bool. Whether to keep the batch dimension when summing
- entropy/KL term. When the sample is per data point, this should be True;
- otherwise (e.g. in a Bayesian NN), this should be False.
- form: ELBOForms constant. Controls how the ELBO is computed. Defaults to
- ELBOForms.default.
- name: name to prefix ops with.
-
- Returns:
- `Tensor` ELBO of the same type and shape as `log_likelihood`.
-
- Raises:
- TypeError: if variationals in `variational_with_prior` are not
- `StochasticTensor`s or if priors are not `Distribution`s.
- TypeError: if form is not a valid ELBOForms constant.
- ValueError: if `variational_with_prior` is None and there are no
- `StochasticTensor`s upstream of `log_likelihood`.
- ValueError: if any variational does not have a prior passed or registered.
- """
- if form is None:
- form = ELBOForms.default
- with ops.name_scope(name):
- model = ops.convert_to_tensor(log_likelihood)
- variational_with_prior = _find_variational_and_priors(
- model, variational_with_prior)
- return _elbo(form, log_likelihood, None, variational_with_prior,
- keep_batch_dim)
-
-
-def elbo_with_log_joint(log_joint,
- variational=None,
- keep_batch_dim=True,
- form=None,
- name="ELBO"):
- """Evidence Lower BOund. `log p(x) >= ELBO`.
-
- This method is for models that have computed `p(x,Z)` instead of `p(x|Z)`.
- See `elbo` for further details.
-
- Because only the joint is specified, analytic KL is not available.
-
- Args:
- log_joint: `Tensor` log p(x, Z).
- variational: list of `StochasticTensor` q(Z). If `None`, defaults to all
- `StochasticTensor` objects upstream of `log_joint`.
- keep_batch_dim: bool. Whether to keep the batch dimension when summing
- entropy term. When the sample is per data point, this should be True;
- otherwise (e.g. in a Bayesian NN), this should be False.
- form: ELBOForms constant. Controls how the ELBO is computed. Defaults to
- ELBOForms.default.
- name: name to prefix ops with.
-
- Returns:
- `Tensor` ELBO of the same type and shape as `log_joint`.
-
- Raises:
- TypeError: if variationals in `variational` are not `StochasticTensor`s.
- TypeError: if form is not a valid ELBOForms constant.
- ValueError: if `variational` is None and there are no `StochasticTensor`s
- upstream of `log_joint`.
- ValueError: if form is ELBOForms.analytic_kl.
- """
- if form is None:
- form = ELBOForms.default
- if form == ELBOForms.analytic_kl:
- raise ValueError("ELBOForms.analytic_kl is not available when using "
- "elbo_with_log_joint. Use elbo or a different form.")
-
- with ops.name_scope(name):
- model = ops.convert_to_tensor(log_joint)
-
- variational_with_prior = None
- if variational is not None:
- variational_with_prior = dict(zip(variational, [None] * len(variational)))
- variational_with_prior = _find_variational_and_priors(
- model, variational_with_prior, require_prior=False)
- return _elbo(form, None, log_joint, variational_with_prior, keep_batch_dim)
-
-
-def _elbo(form, log_likelihood, log_joint, variational_with_prior,
- keep_batch_dim):
- """Internal implementation of ELBO. Users should use `elbo`.
-
- Args:
- form: ELBOForms constant. Controls how the ELBO is computed.
- log_likelihood: `Tensor` log p(x|Z).
- log_joint: `Tensor` log p(x, Z).
-    variational_with_prior: `dict<StochasticTensor, Distribution>` mapping
-      variational distributions to prior distributions.
- keep_batch_dim: bool. Whether to keep the batch dimension when reducing
- the entropy/KL.
-
- Returns:
- ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`.
- """
- ELBOForms.check_form(form)
-
- # Order of preference
- # 1. Analytic KL: log_likelihood - KL(q||p)
- # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q]
- # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) =
-  #    log_likelihood + log p(Z) - log q(Z), or log_joint - log q(Z)
-
- def _reduce(val):
- if keep_batch_dim:
- return val
- else:
- return math_ops.reduce_sum(val)
-
- kl_terms = []
- entropy_terms = []
- prior_terms = []
- for q, z, p in [(qz.distribution, qz.value(), pz)
- for qz, pz in variational_with_prior.items()]:
- # Analytic KL
- kl = None
- if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}:
- try:
- kl = kullback_leibler.kl(q, p)
- logging.info("Using analytic KL between q:%s, p:%s", q, p)
- except NotImplementedError as e:
- if form == ELBOForms.analytic_kl:
- raise e
- if kl is not None:
- kl_terms.append(-1. * _reduce(kl))
- continue
-
- # Analytic entropy
- entropy = None
- if form in {ELBOForms.default, ELBOForms.analytic_entropy}:
- try:
- entropy = q.entropy()
- logging.info("Using analytic entropy for q:%s", q)
- except NotImplementedError as e:
- if form == ELBOForms.analytic_entropy:
- raise e
- if entropy is not None:
- entropy_terms.append(_reduce(entropy))
- if log_likelihood is not None:
- prior = p.log_prob(z)
- prior_terms.append(_reduce(prior))
- continue
-
- # Sample
- if form in {ELBOForms.default, ELBOForms.sample}:
- entropy = -q.log_prob(z)
- entropy_terms.append(_reduce(entropy))
- if log_likelihood is not None:
- prior = p.log_prob(z)
- prior_terms.append(_reduce(prior))
-
- first_term = log_joint if log_joint is not None else log_likelihood
- return sum([first_term] + kl_terms + entropy_terms + prior_terms)
-
-
-def _find_variational_and_priors(model,
- variational_with_prior,
- require_prior=True):
- """Find upstream StochasticTensors and match with registered priors."""
- if variational_with_prior is None:
- # pylint: disable=protected-access
- upstreams = sg._upstream_stochastic_nodes([model])
- # pylint: enable=protected-access
- upstreams = list(upstreams[model])
- if not upstreams:
- raise ValueError("No upstream stochastic nodes found for tensor: %s",
- model)
- prior_map = dict(ops.get_collection(VI_PRIORS))
- variational_with_prior = {}
- for q in upstreams:
- if require_prior and (q not in prior_map or prior_map[q] is None):
- raise ValueError("No prior specified for StochasticTensor: %s", q)
- variational_with_prior[q] = prior_map.get(q)
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.variational_inference_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
- if not all(
- [isinstance(q, st.StochasticTensor) for q in variational_with_prior]):
- raise TypeError("variationals must be StochasticTensors")
- if not all([
- p is None or isinstance(p, distribution.Distribution)
- for p in variational_with_prior.values()
- ]):
- raise TypeError("priors must be Distribution objects")
+_allowed_symbols = [
+ "elbo", "elbo_with_log_joint", "ELBOForms", "register_prior"
+]
- return variational_with_prior
+remove_undocumented(__name__, _allowed_symbols)
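The replacement module above applies the same `*_impl` sealing pattern used for `entropy` and `monte_carlo` in this change: the public file wildcard-imports the implementation module and then calls `remove_undocumented` to prune its namespace down to a whitelist. A minimal sketch of the pattern, using hypothetical module and symbol names:

```
# Sketch of the *_impl sealing pattern (hypothetical names).
#
# my_module_impl.py holds the real code:
#     def public_fn(): ...
#     def _private_helper(): ...
#
# my_module.py becomes a thin, sealed facade:

# pylint: disable=wildcard-import
from myproject.my_module_impl import *  # hypothetical implementation module
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

# Only these names remain attributes of my_module; everything else the
# wildcard import dragged in (private helpers, re-imported modules) is removed.
_allowed_symbols = ["public_fn"]

remove_undocumented(__name__, _allowed_symbols)
```

This keeps the documented API surface explicit while the `_impl` module stays free to expose internals for tests that import it directly.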
diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py
new file mode 100644
index 0000000000..17a8666686
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py
@@ -0,0 +1,327 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Variational inference.
+
+See the ${@python/contrib.bayesflow.variational_inference} guide.
+
+@@elbo
+@@elbo_with_log_joint
+@@ELBOForms
+@@register_prior
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bayesflow.python.ops import stochastic_graph as sg
+from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
+
+VI_PRIORS = "__vi_priors__"
+
+
+def register_prior(variational, prior):
+ """Associate a variational `StochasticTensor` with a `Distribution` prior.
+
+ This is a helper function used in conjunction with `elbo` that allows users
+ to specify the mapping between variational distributions and their priors
+ without having to pass in `variational_with_prior` explicitly.
+
+ Args:
+ variational: `StochasticTensor` q(Z). Approximating distribution.
+ prior: `Distribution` p(Z). Prior distribution.
+
+ Returns:
+ None
+
+ Raises:
+    TypeError: if `variational` is not a `StochasticTensor` or `prior` is not
+      a `Distribution`.
+ """
+ if not isinstance(variational, st.StochasticTensor):
+ raise TypeError("variational must be a StochasticTensor")
+ if not isinstance(prior, distribution.Distribution):
+ raise TypeError("prior must be a Distribution")
+ ops.add_to_collection(VI_PRIORS, (variational, prior))
+
+
+class _ELBOForm(object):
+ pass
+
+
+class ELBOForms(object):
+ """Constants to control the `elbo` calculation.
+
+ `analytic_kl` uses the analytic KL divergence between the
+ variational distribution(s) and the prior(s).
+
+ `analytic_entropy` uses the analytic entropy of the variational
+ distribution(s).
+
+  `sample` uses the sample KL, or the sample entropy if the joint is provided.
+
+ See `elbo` for what is used with `default`.
+ """
+ default, analytic_kl, analytic_entropy, sample = (_ELBOForm()
+ for _ in range(4))
+
+ @staticmethod
+ def check_form(form):
+ if form not in {
+ ELBOForms.default, ELBOForms.analytic_kl, ELBOForms.analytic_entropy,
+ ELBOForms.sample
+ }:
+ raise TypeError("form must be an ELBOForms constant")
+
+
+def elbo(log_likelihood,
+ variational_with_prior=None,
+ keep_batch_dim=True,
+ form=None,
+ name="ELBO"):
+ r"""Evidence Lower BOund. `log p(x) >= ELBO`.
+
+ Optimization objective for inference of hidden variables by variational
+ inference.
+
+ This function is meant to be used in conjunction with `StochasticTensor`.
+ The user should build out the inference network, using `StochasticTensor`s
+ as latent variables, and the generative network. `elbo` at minimum needs
+ `p(x|Z)` and assumes that all `StochasticTensor`s upstream of `p(x|Z)` are
+ the variational distributions. Use `register_prior` to register `Distribution`
+ priors for each `StochasticTensor`. Alternatively, pass in
+ `variational_with_prior` specifying all variational distributions and their
+ priors.
+
+ Mathematical details:
+
+ ```
+ log p(x) = log \int p(x, Z) dZ
+ = log \int \frac {q(Z)p(x, Z)}{q(Z)} dZ
+ = log E_q[\frac {p(x, Z)}{q(Z)}]
+ >= E_q[log \frac {p(x, Z)}{q(Z)}] = L[q; p, x] # ELBO
+
+ L[q; p, x] = E_q[log p(x|Z)p(Z)] - E_q[log q(Z)]
+ = E_q[log p(x|Z)p(Z)] + H[q] (1)
+ = E_q[log p(x|Z)] - KL(q || p) (2)
+
+ H - Entropy
+ KL - Kullback-Leibler divergence
+ ```
+
+ See section 2.2 of Stochastic Variational Inference by Hoffman et al. for
+ more, including the ELBO's equivalence to minimizing `KL(q(Z)||p(Z|x))`
+ in the fully Bayesian setting. https://arxiv.org/pdf/1206.7051.pdf.
+
+ `form` specifies which form of the ELBO is used. `form=ELBOForms.default`
+ tries, in order of preference: analytic KL, analytic entropy, sampling.
+
+  Multiple entries in the `variational_with_prior` dict imply a factorization,
+  e.g. `q(Z) = q(z1)q(z2)q(z3)`.
+
+ Args:
+ log_likelihood: `Tensor` log p(x|Z).
+ variational_with_prior: dict from `StochasticTensor` q(Z) to
+ `Distribution` p(Z). If `None`, defaults to all `StochasticTensor`
+ objects upstream of `log_likelihood` with priors registered with
+ `register_prior`.
+ keep_batch_dim: bool. Whether to keep the batch dimension when summing
+ entropy/KL term. When the sample is per data point, this should be True;
+ otherwise (e.g. in a Bayesian NN), this should be False.
+ form: ELBOForms constant. Controls how the ELBO is computed. Defaults to
+ ELBOForms.default.
+ name: name to prefix ops with.
+
+ Returns:
+ `Tensor` ELBO of the same type and shape as `log_likelihood`.
+
+ Raises:
+ TypeError: if variationals in `variational_with_prior` are not
+ `StochasticTensor`s or if priors are not `Distribution`s.
+ TypeError: if form is not a valid ELBOForms constant.
+ ValueError: if `variational_with_prior` is None and there are no
+ `StochasticTensor`s upstream of `log_likelihood`.
+ ValueError: if any variational does not have a prior passed or registered.
+ """
+ if form is None:
+ form = ELBOForms.default
+ with ops.name_scope(name):
+ model = ops.convert_to_tensor(log_likelihood)
+ variational_with_prior = _find_variational_and_priors(
+ model, variational_with_prior)
+ return _elbo(form, log_likelihood, None, variational_with_prior,
+ keep_batch_dim)
+
+
+def elbo_with_log_joint(log_joint,
+ variational=None,
+ keep_batch_dim=True,
+ form=None,
+ name="ELBO"):
+ """Evidence Lower BOund. `log p(x) >= ELBO`.
+
+ This method is for models that have computed `p(x,Z)` instead of `p(x|Z)`.
+ See `elbo` for further details.
+
+ Because only the joint is specified, analytic KL is not available.
+
+ Args:
+ log_joint: `Tensor` log p(x, Z).
+ variational: list of `StochasticTensor` q(Z). If `None`, defaults to all
+ `StochasticTensor` objects upstream of `log_joint`.
+ keep_batch_dim: bool. Whether to keep the batch dimension when summing
+ entropy term. When the sample is per data point, this should be True;
+ otherwise (e.g. in a Bayesian NN), this should be False.
+ form: ELBOForms constant. Controls how the ELBO is computed. Defaults to
+ ELBOForms.default.
+ name: name to prefix ops with.
+
+ Returns:
+ `Tensor` ELBO of the same type and shape as `log_joint`.
+
+ Raises:
+ TypeError: if variationals in `variational` are not `StochasticTensor`s.
+ TypeError: if form is not a valid ELBOForms constant.
+ ValueError: if `variational` is None and there are no `StochasticTensor`s
+ upstream of `log_joint`.
+ ValueError: if form is ELBOForms.analytic_kl.
+ """
+ if form is None:
+ form = ELBOForms.default
+ if form == ELBOForms.analytic_kl:
+ raise ValueError("ELBOForms.analytic_kl is not available when using "
+ "elbo_with_log_joint. Use elbo or a different form.")
+
+ with ops.name_scope(name):
+ model = ops.convert_to_tensor(log_joint)
+
+ variational_with_prior = None
+ if variational is not None:
+ variational_with_prior = dict(zip(variational, [None] * len(variational)))
+ variational_with_prior = _find_variational_and_priors(
+ model, variational_with_prior, require_prior=False)
+ return _elbo(form, None, log_joint, variational_with_prior, keep_batch_dim)
+
+
+def _elbo(form, log_likelihood, log_joint, variational_with_prior,
+ keep_batch_dim):
+ """Internal implementation of ELBO. Users should use `elbo`.
+
+ Args:
+ form: ELBOForms constant. Controls how the ELBO is computed.
+ log_likelihood: `Tensor` log p(x|Z).
+ log_joint: `Tensor` log p(x, Z).
+    variational_with_prior: `dict<StochasticTensor, Distribution>` mapping
+      variational distributions to prior distributions.
+ keep_batch_dim: bool. Whether to keep the batch dimension when reducing
+ the entropy/KL.
+
+ Returns:
+ ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`.
+ """
+ ELBOForms.check_form(form)
+
+ # Order of preference
+ # 1. Analytic KL: log_likelihood - KL(q||p)
+ # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q]
+ # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) =
+  #    log_likelihood + log p(Z) - log q(Z), or log_joint - log q(Z)
+
+ def _reduce(val):
+ if keep_batch_dim:
+ return val
+ else:
+ return math_ops.reduce_sum(val)
+
+ kl_terms = []
+ entropy_terms = []
+ prior_terms = []
+ for q, z, p in [(qz.distribution, qz.value(), pz)
+ for qz, pz in variational_with_prior.items()]:
+ # Analytic KL
+ kl = None
+ if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}:
+ try:
+ kl = kullback_leibler.kl(q, p)
+ logging.info("Using analytic KL between q:%s, p:%s", q, p)
+ except NotImplementedError as e:
+ if form == ELBOForms.analytic_kl:
+ raise e
+ if kl is not None:
+ kl_terms.append(-1. * _reduce(kl))
+ continue
+
+ # Analytic entropy
+ entropy = None
+ if form in {ELBOForms.default, ELBOForms.analytic_entropy}:
+ try:
+ entropy = q.entropy()
+ logging.info("Using analytic entropy for q:%s", q)
+ except NotImplementedError as e:
+ if form == ELBOForms.analytic_entropy:
+ raise e
+ if entropy is not None:
+ entropy_terms.append(_reduce(entropy))
+ if log_likelihood is not None:
+ prior = p.log_prob(z)
+ prior_terms.append(_reduce(prior))
+ continue
+
+ # Sample
+ if form in {ELBOForms.default, ELBOForms.sample}:
+ entropy = -q.log_prob(z)
+ entropy_terms.append(_reduce(entropy))
+ if log_likelihood is not None:
+ prior = p.log_prob(z)
+ prior_terms.append(_reduce(prior))
+
+ first_term = log_joint if log_joint is not None else log_likelihood
+ return sum([first_term] + kl_terms + entropy_terms + prior_terms)
+
+
+def _find_variational_and_priors(model,
+ variational_with_prior,
+ require_prior=True):
+ """Find upstream StochasticTensors and match with registered priors."""
+ if variational_with_prior is None:
+ # pylint: disable=protected-access
+ upstreams = sg._upstream_stochastic_nodes([model])
+ # pylint: enable=protected-access
+ upstreams = list(upstreams[model])
+ if not upstreams:
+ raise ValueError("No upstream stochastic nodes found for tensor: %s",
+ model)
+ prior_map = dict(ops.get_collection(VI_PRIORS))
+ variational_with_prior = {}
+ for q in upstreams:
+ if require_prior and (q not in prior_map or prior_map[q] is None):
+ raise ValueError("No prior specified for StochasticTensor: %s", q)
+ variational_with_prior[q] = prior_map.get(q)
+
+ if not all(
+ [isinstance(q, st.StochasticTensor) for q in variational_with_prior]):
+ raise TypeError("variationals must be StochasticTensors")
+ if not all([
+ p is None or isinstance(p, distribution.Distribution)
+ for p in variational_with_prior.values()
+ ]):
+ raise TypeError("priors must be Distribution objects")
+
+ return variational_with_prior
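For orientation, a hedged usage sketch of the API moved into `variational_inference_impl`, in the style of a small VAE. The `encoder`/`decoder` networks and the data tensor are placeholders, and the distribution parameter names (`mu`/`sigma`) follow the contrib.distributions API of this era; treat it as an illustration of `register_prior` plus `elbo`, not as code from this change:

```
import tensorflow as tf

from tensorflow.contrib import distributions
from tensorflow.contrib.bayesflow import stochastic_tensor as st
from tensorflow.contrib.bayesflow import variational_inference as vi

x = tf.placeholder(tf.float32, [None, 784])   # observed data (placeholder)
q_mu, q_sigma = encoder(x)                    # hypothetical inference network

# Variational posterior q(Z|x) as a StochasticTensor that samples its value.
with st.value_type(st.SampleValue()):
  q_z = st.StochasticTensor(distributions.Normal(mu=q_mu, sigma=q_sigma))

# Register the prior p(Z) so elbo() can pair it with q_z automatically,
# instead of passing variational_with_prior={q_z: prior} explicitly.
vi.register_prior(q_z, distributions.Normal(mu=0., sigma=1.))

# Generative network p(x|Z); because q_z feeds decoder(), q_z is upstream of
# the likelihood and elbo() will discover it.
p_x_given_z = distributions.Bernoulli(logits=decoder(q_z))
log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x), 1)   # per example

# ELBOForms.default tries analytic KL, then analytic entropy, then sampling.
elbo = vi.elbo(log_likelihood)      # same shape as log_likelihood
loss = -tf.reduce_mean(elbo)        # train by minimizing the negative ELBO
```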
diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py
index ee181ef86f..75748a3ce1 100644
--- a/tensorflow/tools/docs/generate.py
+++ b/tensorflow/tools/docs/generate.py
@@ -199,10 +199,8 @@ def extract():
'tfprof',
],
'contrib.bayesflow': [
- 'entropy', 'monte_carlo',
- 'special_math', 'stochastic_gradient_estimators',
- 'stochastic_graph', 'stochastic_tensor',
- 'stochastic_variables', 'variational_inference'
+ 'special_math', 'stochastic_gradient_estimators', 'stochastic_graph',
+ 'stochastic_tensor', 'stochastic_variables'
],
'contrib.distributions': ['bijector'],
'contrib.ffmpeg': ['ffmpeg_ops'],
@@ -215,10 +213,7 @@ def extract():
'select',
'util'
],
- 'contrib.layers': [
- 'feature_column',
- 'summaries'
- ],
+ 'contrib.layers': ['feature_column', 'summaries'],
'contrib.learn': [
'datasets',
'head',