author     Joshua V. Dillon <jvdillon@google.com>    2018-03-07 15:00:43 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-03-07 15:05:03 -0800
commit     5e7b3556619a4a6450b588d8b2f173729ffc9203 (patch)
tree       d2c9f5353c813f34d2e5414d52cbeefe8dd6d276 /tensorflow/contrib/bayesflow
parent     fffb7b59f5695b36af4e03c1dd8eadff3fd0024c (diff)
Migrate AIS chain into `tfp.mcmc` and modularize its interface to take a TransitionKernel.
PiperOrigin-RevId: 188239559
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py  132
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/hmc.py                   1
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/hmc_impl.py            217
3 files changed, 0 insertions, 350 deletions
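
Per the commit message, this functionality now lives in `tfp.mcmc` with a modular TransitionKernel interface. As a rough sketch only, assuming the post-migration TensorFlow Probability API (the `make_kernel_fn` argument and `tfp.mcmc.HamiltonianMonteCarlo` names here are assumptions about the migration target, not part of this diff), the log-gamma example from the removed docstring might translate as:

```python
# Hypothetical post-migration usage; names outside this diff are assumptions.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dims, num_chains = 20, 100

proposal = tfd.MultivariateNormalDiag(loc=tf.zeros([dims]))
target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=2., rate=3.),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

# The modularized interface accepts any TransitionKernel via `make_kernel_fn`;
# HMC mirrors the hard-coded kernel in the removed implementation.
def make_kernel_fn(target_log_prob_fn):
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.2,
      num_leapfrog_steps=2)

_, ais_weights, _ = tfp.mcmc.sample_annealed_importance_chain(
    num_steps=1000,
    proposal_log_prob_fn=proposal.log_prob,
    target_log_prob_fn=target.log_prob,
    current_state=proposal.sample(num_chains),
    make_kernel_fn=make_kernel_fn)

log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                           - np.log(num_chains))
```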
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
index 819095a060..dabadfc7b6 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
@@ -462,138 +462,6 @@ class HMCTest(test.TestCase):
def testKernelLeavesTargetInvariant3(self):
self._kernel_leaves_target_invariant_wrapper(3)
- def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
- sess, feed_dict=None):
- counter = collections.Counter()
-
- def proposal_log_prob(x):
- counter["proposal_calls"] += 1
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
- axis=event_dims)
-
- def target_log_prob(x):
- counter["target_calls"] += 1
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return self._log_gamma_log_prob(x, event_dims)
-
- if feed_dict is None:
- feed_dict = {}
-
- num_steps = 200
-
- _, ais_weights, _ = hmc.sample_annealed_importance_chain(
- proposal_log_prob_fn=proposal_log_prob,
- num_steps=num_steps,
- target_log_prob_fn=target_log_prob,
- step_size=0.5,
- current_state=init,
- num_leapfrog_steps=2,
- seed=45)
-
- # We have three calls because the calculation of `ais_weights` entails
- # another call to the `convex_combined_log_prob_fn`. We could refactor
- # things to avoid this, if needed (eg, b/72994218).
- self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter)
-
- event_shape = array_ops.shape(init)[independent_chain_ndims:]
- event_size = math_ops.reduce_prod(event_shape)
-
- log_true_normalizer = (
- -self._shape_param * math_ops.log(self._rate_param)
- + math_ops.lgamma(self._shape_param))
- log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
-
- log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
- - np.log(num_steps))
-
- ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
- ais_weights_size = array_ops.size(ais_weights)
- standard_error = math_ops.sqrt(
- _reduce_variance(ratio_estimate_true)
- / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
-
- [
- ratio_estimate_true_,
- log_true_normalizer_,
- log_estimated_normalizer_,
- standard_error_,
- ais_weights_size_,
- event_size_,
- ] = sess.run([
- ratio_estimate_true,
- log_true_normalizer,
- log_estimated_normalizer,
- standard_error,
- ais_weights_size,
- event_size,
- ], feed_dict)
-
- logging_ops.vlog(1, " log_true_normalizer: {}\n"
- " log_estimated_normalizer: {}\n"
- " ais_weights_size: {}\n"
- " event_size: {}\n".format(
- log_true_normalizer_,
- log_estimated_normalizer_,
- ais_weights_size_,
- event_size_))
- self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
-
- def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
- """Tests that AIS yields reasonable estimates of normalizers."""
- with self.test_session(graph=ops.Graph()) as sess:
- x_ph = array_ops.placeholder(np.float32, name="x_ph")
- initial_draws = np.random.normal(size=[30, 2, 1])
- self._ais_gets_correct_log_normalizer(
- x_ph,
- independent_chain_ndims,
- sess,
- feed_dict={x_ph: initial_draws})
-
- def testAIS1(self):
- self._ais_gets_correct_log_normalizer_wrapper(1)
-
- def testAIS2(self):
- self._ais_gets_correct_log_normalizer_wrapper(2)
-
- def testAIS3(self):
- self._ais_gets_correct_log_normalizer_wrapper(3)
-
- def testSampleAIChainSeedReproducibleWorksCorrectly(self):
- with self.test_session(graph=ops.Graph()) as sess:
- independent_chain_ndims = 1
- x = np.random.rand(4, 3, 2)
-
- def proposal_log_prob(x):
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
- axis=event_dims)
-
- def target_log_prob(x):
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return self._log_gamma_log_prob(x, event_dims)
-
- ais_kwargs = dict(
- proposal_log_prob_fn=proposal_log_prob,
- num_steps=200,
- target_log_prob_fn=target_log_prob,
- step_size=0.5,
- current_state=x,
- num_leapfrog_steps=2,
- seed=53)
-
- _, ais_weights0, _ = hmc.sample_annealed_importance_chain(
- **ais_kwargs)
-
- _, ais_weights1, _ = hmc.sample_annealed_importance_chain(
- **ais_kwargs)
-
- [ais_weights0_, ais_weights1_] = sess.run([
- ais_weights0, ais_weights1])
-
- self.assertAllClose(ais_weights0_, ais_weights1_,
- atol=1e-5, rtol=1e-5)
-
def testNanRejection(self):
"""Tests that an update that yields NaN potentials gets rejected.
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py
index 7fd5652c5c..c8a5a195d3 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py
@@ -24,7 +24,6 @@ from tensorflow.python.util import all_util
_allowed_symbols = [
"sample_chain",
- "sample_annealed_importance_chain",
"kernel",
]
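
For context on the implementation removed below: it anneals along a geometric path, sampling at step `i` from the convex combination `(1 - beta) * proposal_log_prob + beta * target_log_prob` with `beta = (i + 1) / num_steps`, and accumulates `(target_log_prob - proposal_log_prob) / num_steps` into `ais_weights` at each step, so that `E[exp(ais_weights)]` equals the ratio of normalizing constants. A minimal standalone sketch of that schedule (plain Python, with placeholder log-prob callables):

```python
# Minimal sketch of the geometric annealing path used by the removed
# implementation (cf. `make_convex_combined_log_prob_fn` in the diff below).
# `proposal_log_prob` and `target_log_prob` are placeholder callables.
def make_annealed_log_prob_fn(iter_, num_steps,
                              proposal_log_prob, target_log_prob):
  beta = float(iter_ + 1) / num_steps  # ramps from 1/num_steps up to 1
  def _fn(x):
    return (1. - beta) * proposal_log_prob(x) + beta * target_log_prob(x)
  return _fn
```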
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
index 82693c2b7b..66afcc7497 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
@@ -15,7 +15,6 @@
"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
@@sample_chain
-@@sample_annealed_importance_chain
@@kernel
"""
@@ -38,7 +37,6 @@ from tensorflow.python.ops.distributions import util as distributions_util
__all__ = [
"sample_chain",
- "sample_annealed_importance_chain",
"kernel",
]
@@ -330,221 +328,6 @@ def sample_chain(
return functional_ops.scan(**scan_kwargs)
-def sample_annealed_importance_chain(
- proposal_log_prob_fn,
- num_steps,
- target_log_prob_fn,
- current_state,
- step_size,
- num_leapfrog_steps,
- seed=None,
- name=None):
- """Runs annealed importance sampling (AIS) to estimate normalizing constants.
-
- This function uses Hamiltonian Monte Carlo to sample from a series of
- distributions that slowly interpolates between an initial "proposal"
- distribution:
-
- `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`
-
- and the target distribution:
-
- `exp(target_log_prob_fn(x) - target_log_normalizer)`,
-
- accumulating importance weights along the way. The product of these
- importance weights gives an unbiased estimate of the ratio of the
- normalizing constants of the initial distribution and the target
- distribution:
-
- `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
-
- Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three
- times (although this may be reduced to two times, in the future).
-
- #### Examples:
-
- ##### Estimate the normalizing constant of a log-gamma distribution.
-
- ```python
- tfd = tf.contrib.distributions
-
- # Run 100 AIS chains in parallel
- num_chains = 100
- dims = 20
- dtype = np.float32
-
- proposal = tfd.MultivariateNormalDiag(
- loc=tf.zeros([dims], dtype=dtype))
-
- target = tfd.TransformedDistribution(
- distribution=tfd.Gamma(concentration=dtype(2),
- rate=dtype(3)),
- bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
- event_shape=[dims])
-
- chains_state, ais_weights, kernels_results = (
- hmc.sample_annealed_importance_chain(
- proposal_log_prob_fn=proposal.log_prob,
- num_steps=1000,
- target_log_prob_fn=target.log_prob,
- step_size=0.2,
- current_state=proposal.sample(num_chains),
- num_leapfrog_steps=2))
-
- log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
- - np.log(num_chains))
- log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
- ```
-
- ##### Estimate marginal likelihood of a Bayesian regression model.
-
- ```python
- tfd = tf.contrib.distributions
-
- def make_prior(dims, dtype):
- return tfd.MultivariateNormalDiag(
- loc=tf.zeros(dims, dtype))
-
- def make_likelihood(weights, x):
- return tfd.MultivariateNormalDiag(
- loc=tf.tensordot(weights, x, axes=[[0], [-1]]))
-
- # Run 100 AIS chains in parallel
- num_chains = 100
- dims = 10
- dtype = np.float32
-
- # Make training data.
- x = np.random.randn(num_chains, dims).astype(dtype)
- true_weights = np.random.randn(dims).astype(dtype)
- y = np.dot(x, true_weights) + np.random.randn(num_chains)
-
- # Setup model.
- prior = make_prior(dims, dtype)
- def target_log_prob_fn(weights):
- return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
-
- proposal = tfd.MultivariateNormalDiag(
- loc=tf.zeros(dims, dtype))
-
- weight_samples, ais_weights, kernel_results = (
- hmc.sample_annealed_importance_chain(
- num_steps=1000,
- proposal_log_prob_fn=proposal.log_prob,
- target_log_prob_fn=target_log_prob_fn,
- current_state=tf.zeros([num_chains, dims], dtype),
- step_size=0.1,
- num_leapfrog_steps=2))
- log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
- - np.log(num_chains))
- ```
-
- Args:
- proposal_log_prob_fn: Python callable that returns the log density of the
- initial distribution.
- num_steps: Integer number of Markov chain updates to run. More
- iterations means more expense, but smoother annealing between q
- and p, which in turn means exponentially lower variance for the
- normalizing constant estimator.
- target_log_prob_fn: Python callable which takes an argument like
- `current_state` (or `*current_state` if it's a list) and returns its
- (possibly unnormalized) log-density under the target distribution.
- current_state: `Tensor` or Python `list` of `Tensor`s representing the
- current state(s) of the Markov chain(s). The first `r` dimensions index
- independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
- step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
- for the leapfrog integrator. Must broadcast with the shape of
- `current_state`. Larger step sizes lead to faster progress, but too-large
- step sizes make rejection exponentially more likely. When possible, it's
- often helpful to match per-variable step sizes to the standard deviations
- of the target distribution in each variable.
- num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
- for. Total progress per HMC step is roughly proportional to `step_size *
- num_leapfrog_steps`.
- seed: Python integer to seed the random number generator.
- name: Python `str` name prefixed to Ops created by this function.
- Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
-
- Returns:
- next_state: `Tensor` or Python list of `Tensor`s representing the
- state(s) of the Markov chain(s) at the final iteration. Has same shape as
- input `current_state`.
- ais_weights: Tensor with the estimated weight(s). Has shape matching
- `target_log_prob_fn(current_state)`.
- kernel_results: `collections.namedtuple` of internal calculations used to
- advance the chain.
- """
- def make_convex_combined_log_prob_fn(iter_):
- def _fn(*args):
- p = proposal_log_prob_fn(*args)
- t = target_log_prob_fn(*args)
- dtype = p.dtype.base_dtype
- beta = (math_ops.cast(iter_ + 1, dtype)
- / math_ops.cast(num_steps, dtype))
- return (1. - beta) * p + beta * t
- return _fn
-
- with ops.name_scope(
- name, "hmc_sample_annealed_importance_chain",
- [num_steps, current_state, step_size, num_leapfrog_steps, seed]):
- with ops.name_scope("initialize"):
- [
- current_state,
- step_size,
- current_log_prob,
- current_grads_log_prob,
- ] = _prepare_args(
- make_convex_combined_log_prob_fn(iter_=0),
- current_state,
- step_size,
- description="convex_combined_log_prob")
- num_steps = ops.convert_to_tensor(
- num_steps,
- dtype=dtypes.int32,
- name="num_steps")
- num_leapfrog_steps = ops.convert_to_tensor(
- num_leapfrog_steps,
- dtype=dtypes.int32,
- name="num_leapfrog_steps")
- def _loop_body(iter_, ais_weights, current_state, kernel_results):
- """Closure which implements `tf.while_loop` body."""
- current_state_parts = (list(current_state)
- if _is_list_like(current_state)
- else [current_state])
- # TODO(b/72994218): Consider refactoring things to avoid this unnecessary
- # call.
- ais_weights += ((target_log_prob_fn(*current_state_parts)
- - proposal_log_prob_fn(*current_state_parts))
- / math_ops.cast(num_steps, ais_weights.dtype))
- return [iter_ + 1, ais_weights] + list(kernel(
- make_convex_combined_log_prob_fn(iter_),
- current_state,
- step_size,
- num_leapfrog_steps,
- seed,
- kernel_results.current_target_log_prob,
- kernel_results.current_grads_target_log_prob))
-
- while_loop_kwargs = dict(
- cond=lambda iter_, *args: iter_ < num_steps,
- body=_loop_body,
- loop_vars=[
- np.int32(0), # iter_
- array_ops.zeros_like(current_log_prob), # ais_weights
- current_state,
- _make_dummy_kernel_results(current_state,
- current_log_prob,
- current_grads_log_prob),
- ])
- if seed is not None:
- while_loop_kwargs["parallel_iterations"] = 1
-
- [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop(
- **while_loop_kwargs)[1:] # Lop-off "iter_".
-
- return [current_state, ais_weights, kernel_results]
-
-
def kernel(target_log_prob_fn,
current_state,
step_size,
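
Finally, a small self-contained sketch of the sanity check performed by the removed `_ais_gets_correct_log_normalizer` test above: the sample mean of `exp(ais_weights - log_true_normalizer)` should fall within a few standard errors of 1. The synthetic `ais_weights_` below is a stand-in purely for illustration, not output from the removed code:

```python
import numpy as np

# Stand-in weights; a real run would use the evaluated `ais_weights` and the
# analytically known log-normalizer of the target distribution.
log_true_normalizer_ = 0.
ais_weights_ = np.random.normal(loc=0., scale=0.1, size=100)

ratio = np.exp(ais_weights_ - log_true_normalizer_)
standard_error = np.sqrt(ratio.var() / ratio.size)
assert abs(ratio.mean() - 1.) < 4. * standard_error
```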