path: root/tensorflow/contrib/bayesflow
author    Joshua V. Dillon <jvdillon@google.com>    2018-02-06 14:17:52 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-02-06 14:22:15 -0800
commit    71248e6f4f79f7d9b6f35854e6bab2caeabfb555 (patch)
tree      5ddf0a35fc04df5f7996a38f4bcf4f51ceaad9c0 /tensorflow/contrib/bayesflow
parent    497b4fb1440d95161a7e7c577557fa4c101a6b98 (diff)
Automated g4 rollback of changelist 184551259
PiperOrigin-RevId: 184738583
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py |  831
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/hmc.py               |   11
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/hmc_impl.py          | 1598
3 files changed, 1650 insertions, 790 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
index cbc66b6dc1..51aed6438d 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
@@ -18,30 +18,40 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
import numpy as np
-from scipy import special
from scipy import stats
from tensorflow.contrib.bayesflow.python.ops import hmc
+from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change
+from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import gradients_impl as gradients_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import gamma as gamma_lib
+from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.platform import tf_logging as logging_ops
+
+
+def _reduce_variance(x, axis=None, keepdims=False):
+ sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
+ return math_ops.reduce_mean(
+ math_ops.squared_difference(x, sample_mean), axis, keepdims)
-# TODO(b/66964210): Test float16.
class HMCTest(test.TestCase):
def setUp(self):
self._shape_param = 5.
self._rate_param = 10.
- self._expected_x = (special.digamma(self._shape_param)
- - np.log(self._rate_param))
- self._expected_exp_x = self._shape_param / self._rate_param
random_seed.set_random_seed(10003)
np.random.seed(10003)
@@ -63,63 +73,46 @@ class HMCTest(test.TestCase):
self._rate_param * math_ops.exp(x),
event_dims)
- def _log_gamma_log_prob_grad(self, x, event_dims=()):
- """Computes log-pdf and gradient of a log-gamma random variable.
-
- Args:
- x: Value of the random variable.
- event_dims: Dimensions not to treat as independent. Default is (),
- i.e., all dimensions are independent.
-
- Returns:
- log_prob: The log-pdf up to a normalizing constant.
- grad: The gradient of the log-pdf with respect to x.
- """
- return (math_ops.reduce_sum(self._shape_param * x -
- self._rate_param * math_ops.exp(x),
- event_dims),
- self._shape_param - self._rate_param * math_ops.exp(x))
-
- def _n_event_dims(self, x_shape, event_dims):
- return np.prod([int(x_shape[i]) for i in event_dims])
-
- def _integrator_conserves_energy(self, x, event_dims, sess,
+ def _integrator_conserves_energy(self, x, independent_chain_ndims, sess,
feed_dict=None):
- def potential_and_grad(x):
- log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims)
- return -log_prob, -grad
-
- step_size = array_ops.placeholder(np.float32, [], name='step_size')
- hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
+ step_size = array_ops.placeholder(np.float32, [], name="step_size")
+ hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps")
if feed_dict is None:
feed_dict = {}
feed_dict[hmc_lf_steps] = 1000
- m = random_ops.random_normal(array_ops.shape(x))
- potential_0, grad_0 = potential_and_grad(x)
- old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m,
- event_dims)
-
- _, new_m, potential_1, _ = (
- hmc.leapfrog_integrator(step_size, hmc_lf_steps, x,
- m, potential_and_grad, grad_0))
+ event_dims = math_ops.range(independent_chain_ndims,
+ array_ops.rank(x))
- new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
+ m = random_ops.random_normal(array_ops.shape(x))
+ log_prob_0 = self._log_gamma_log_prob(x, event_dims)
+ grad_0 = gradients_ops.gradients(log_prob_0, x)
+ old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims)
+
+ new_m, _, log_prob_1, _ = _leapfrog_integrator(
+ current_momentums=[m],
+ target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims),
+ current_state_parts=[x],
+ step_sizes=[step_size],
+ num_leapfrog_steps=hmc_lf_steps,
+ current_target_log_prob=log_prob_0,
+ current_grads_target_log_prob=grad_0)
+ new_m = new_m[0]
+
+ new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
event_dims)
x_shape = sess.run(x, feed_dict).shape
- n_event_dims = self._n_event_dims(x_shape, event_dims)
- feed_dict[step_size] = 0.1 / n_event_dims
- old_energy_val, new_energy_val = sess.run([old_energy, new_energy],
- feed_dict)
- logging.vlog(1, 'average energy change: {}'.format(
- abs(old_energy_val - new_energy_val).mean()))
-
- self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool),
- abs(old_energy_val - new_energy_val) < 1.)
-
- def _integrator_conserves_energy_wrapper(self, event_dims):
+ event_size = np.prod(x_shape[independent_chain_ndims:])
+ feed_dict[step_size] = 0.1 / event_size
+ old_energy_, new_energy_ = sess.run([old_energy, new_energy],
+ feed_dict)
+ logging_ops.vlog(1, "average energy relative change: {}".format(
+ (1. - new_energy_ / old_energy_).mean()))
+ self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
+
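
The test above checks that the Hamiltonian, `-log_prob(x) + 0.5 * ||m||**2`, is nearly unchanged by the leapfrog integrator. A minimal NumPy sketch of why that holds, assuming a 1-D standard-normal target (illustrative, not part of the patch):

```python
import numpy as np

def leapfrog(x, m, step_size, num_steps, grad_log_prob):
  # Standard leapfrog: half momentum step, alternating full steps, half step.
  m = m + 0.5 * step_size * grad_log_prob(x)
  for _ in range(num_steps - 1):
    x = x + step_size * m
    m = m + step_size * grad_log_prob(x)
  x = x + step_size * m
  m = m + 0.5 * step_size * grad_log_prob(x)
  return x, m

# Assumed target: standard normal, so -log_prob(x) = 0.5 * x**2 + const.
energy = lambda x, m: 0.5 * x**2 + 0.5 * m**2
x0, m0 = 1.5, 0.5
x1, m1 = leapfrog(x0, m0, step_size=0.1, num_steps=1000,
                  grad_log_prob=lambda x: -x)
# The integrator is symplectic, so the energy drift stays tiny even after
# 1000 steps; this mirrors the rtol=0.02 assertion above.
print(abs(energy(x1, m1) - energy(x0, m0)) / energy(x0, m0))  # ~1e-3
```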
+ def _integrator_conserves_energy_wrapper(self, independent_chain_ndims):
"""Tests the long-term energy conservation of the leapfrog integrator.
The leapfrog integrator is symplectic, so for sufficiently small step
@@ -127,135 +120,218 @@ class HMCTest(test.TestCase):
the energy of the system blowing up or collapsing.
Args:
- event_dims: A tuple of dimensions that should not be treated as
- independent. This allows for multiple chains to be run independently
- in parallel. Default is (), i.e., all dimensions are independent.
+ independent_chain_ndims: Python `int` scalar representing the number of
+ dims associated with independent chains.
"""
- with self.test_session() as sess:
- x_ph = array_ops.placeholder(np.float32, name='x_ph')
-
- feed_dict = {x_ph: np.zeros([50, 10, 2])}
- self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict)
+ with self.test_session(graph=ops.Graph()) as sess:
+ x_ph = array_ops.placeholder(np.float32, name="x_ph")
+ feed_dict = {x_ph: np.random.rand(50, 10, 2)}
+ self._integrator_conserves_energy(x_ph, independent_chain_ndims,
+ sess, feed_dict)
def testIntegratorEnergyConservationNullShape(self):
- self._integrator_conserves_energy_wrapper([])
+ self._integrator_conserves_energy_wrapper(0)
def testIntegratorEnergyConservation1(self):
- self._integrator_conserves_energy_wrapper([1])
+ self._integrator_conserves_energy_wrapper(1)
def testIntegratorEnergyConservation2(self):
- self._integrator_conserves_energy_wrapper([2])
-
- def testIntegratorEnergyConservation12(self):
- self._integrator_conserves_energy_wrapper([1, 2])
-
- def testIntegratorEnergyConservation012(self):
- self._integrator_conserves_energy_wrapper([0, 1, 2])
-
- def _chain_gets_correct_expectations(self, x, event_dims, sess,
- feed_dict=None):
+ self._integrator_conserves_energy_wrapper(2)
+
+ def testIntegratorEnergyConservation3(self):
+ self._integrator_conserves_energy_wrapper(3)
+
+ def testSampleChainSeedReproducibleWorksCorrectly(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ num_results = 10
+ independent_chain_ndims = 1
+
+ def log_gamma_log_prob(x):
+ event_dims = math_ops.range(independent_chain_ndims,
+ array_ops.rank(x))
+ return self._log_gamma_log_prob(x, event_dims)
+
+ kwargs = dict(
+ target_log_prob_fn=log_gamma_log_prob,
+ current_state=np.random.rand(4, 3, 2),
+ step_size=0.1,
+ num_leapfrog_steps=2,
+ num_burnin_steps=150,
+ seed=52,
+ )
+
+ samples0, kernel_results0 = hmc.sample_chain(
+ **dict(list(kwargs.items()) + list(dict(
+ num_results=2 * num_results,
+ num_steps_between_results=0).items())))
+
+ samples1, kernel_results1 = hmc.sample_chain(
+ **dict(list(kwargs.items()) + list(dict(
+ num_results=num_results,
+ num_steps_between_results=1).items())))
+
+ [
+ samples0_,
+ samples1_,
+ target_log_prob0_,
+ target_log_prob1_,
+ ] = sess.run([
+ samples0,
+ samples1,
+ kernel_results0.current_target_log_prob,
+ kernel_results1.current_target_log_prob,
+ ])
+ self.assertAllClose(samples0_[::2], samples1_,
+ atol=1e-5, rtol=1e-5)
+ self.assertAllClose(target_log_prob0_[::2], target_log_prob1_,
+ atol=1e-5, rtol=1e-5)
+
+ def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
+ sess, feed_dict=None):
+ counter = collections.Counter()
def log_gamma_log_prob(x):
+ counter["target_calls"] += 1
+ event_dims = math_ops.range(independent_chain_ndims,
+ array_ops.rank(x))
return self._log_gamma_log_prob(x, event_dims)
- step_size = array_ops.placeholder(np.float32, [], name='step_size')
- hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps')
- hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps')
+ num_results = array_ops.placeholder(
+ np.int32, [], name="num_results")
+ step_size = array_ops.placeholder(
+ np.float32, [], name="step_size")
+ num_leapfrog_steps = array_ops.placeholder(
+ np.int32, [], name="num_leapfrog_steps")
if feed_dict is None:
feed_dict = {}
- feed_dict.update({step_size: 0.1,
- hmc_lf_steps: 2,
- hmc_n_steps: 300})
-
- sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps],
- step_size,
- hmc_lf_steps,
- x, log_gamma_log_prob,
- event_dims)
-
- acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain],
- feed_dict)
- samples = samples[feed_dict[hmc_n_steps] // 2:]
- expected_x_est = samples.mean()
- expected_exp_x_est = np.exp(samples).mean()
-
- logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format(
- self._expected_x, self._expected_exp_x))
- logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format(
- expected_x_est, expected_exp_x_est))
- self.assertNear(expected_x_est, self._expected_x, 2e-2)
- self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2)
- self.assertTrue((acceptance_probs > 0.5).all())
- self.assertTrue((acceptance_probs <= 1.0).all())
-
- def _chain_gets_correct_expectations_wrapper(self, event_dims):
- with self.test_session() as sess:
- x_ph = array_ops.placeholder(np.float32, name='x_ph')
-
- feed_dict = {x_ph: np.zeros([50, 10, 2])}
- self._chain_gets_correct_expectations(x_ph, event_dims, sess,
- feed_dict)
+ feed_dict.update({num_results: 150,
+ step_size: 0.05,
+ num_leapfrog_steps: 2})
+
+ samples, kernel_results = hmc.sample_chain(
+ num_results=num_results,
+ target_log_prob_fn=log_gamma_log_prob,
+ current_state=x,
+ step_size=step_size,
+ num_leapfrog_steps=num_leapfrog_steps,
+ num_burnin_steps=150,
+ seed=42)
+
+ self.assertAllEqual(dict(target_calls=2), counter)
+
+ expected_x = (math_ops.digamma(self._shape_param)
+ - np.log(self._rate_param))
+
+ expected_exp_x = self._shape_param / self._rate_param
+
+ acceptance_probs_, samples_, expected_x_ = sess.run(
+ [kernel_results.acceptance_probs, samples, expected_x],
+ feed_dict)
+
+ actual_x = samples_.mean()
+ actual_exp_x = np.exp(samples_).mean()
+
+ logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format(
+ expected_x_, expected_exp_x))
+ logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
+ actual_x, actual_exp_x))
+ self.assertNear(actual_x, expected_x_, 2e-2)
+ self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
+ self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
+ acceptance_probs_ > 0.5)
+ self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
+ acceptance_probs_ <= 1.)
+
+ def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
+ with self.test_session(graph=ops.Graph()) as sess:
+ x_ph = array_ops.placeholder(np.float32, name="x_ph")
+ feed_dict = {x_ph: np.random.rand(50, 10, 2)}
+ self._chain_gets_correct_expectations(x_ph, independent_chain_ndims,
+ sess, feed_dict)
def testHMCChainExpectationsNullShape(self):
- self._chain_gets_correct_expectations_wrapper([])
+ self._chain_gets_correct_expectations_wrapper(0)
def testHMCChainExpectations1(self):
- self._chain_gets_correct_expectations_wrapper([1])
+ self._chain_gets_correct_expectations_wrapper(1)
def testHMCChainExpectations2(self):
- self._chain_gets_correct_expectations_wrapper([2])
-
- def testHMCChainExpectations12(self):
- self._chain_gets_correct_expectations_wrapper([1, 2])
+ self._chain_gets_correct_expectations_wrapper(2)
- def _kernel_leaves_target_invariant(self, initial_draws, event_dims,
+ def _kernel_leaves_target_invariant(self, initial_draws,
+ independent_chain_ndims,
sess, feed_dict=None):
def log_gamma_log_prob(x):
+ event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
return self._log_gamma_log_prob(x, event_dims)
def fake_log_prob(x):
"""Cooled version of the target distribution."""
return 1.1 * log_gamma_log_prob(x)
- step_size = array_ops.placeholder(np.float32, [], name='step_size')
+ step_size = array_ops.placeholder(np.float32, [], name="step_size")
if feed_dict is None:
feed_dict = {}
feed_dict[step_size] = 0.4
- sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws,
- log_gamma_log_prob, event_dims)
- bad_sample, bad_acceptance_probs, _, _ = hmc.kernel(
- step_size, 5, initial_draws, fake_log_prob, event_dims)
- (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val,
- updated_draws_val, fake_draws_val) = sess.run([acceptance_probs,
- bad_acceptance_probs,
- initial_draws, sample,
- bad_sample], feed_dict)
+ sample, kernel_results = hmc.kernel(
+ target_log_prob_fn=log_gamma_log_prob,
+ current_state=initial_draws,
+ step_size=step_size,
+ num_leapfrog_steps=5,
+ seed=43)
+
+ bad_sample, bad_kernel_results = hmc.kernel(
+ target_log_prob_fn=fake_log_prob,
+ current_state=initial_draws,
+ step_size=step_size,
+ num_leapfrog_steps=5,
+ seed=44)
+
+ [
+ acceptance_probs_,
+ bad_acceptance_probs_,
+ initial_draws_,
+ updated_draws_,
+ fake_draws_,
+ ] = sess.run([
+ kernel_results.acceptance_probs,
+ bad_kernel_results.acceptance_probs,
+ initial_draws,
+ sample,
+ bad_sample,
+ ], feed_dict)
+
# Confirm step size is small enough that we usually accept.
- self.assertGreater(acceptance_probs_val.mean(), 0.5)
- self.assertGreater(bad_acceptance_probs_val.mean(), 0.5)
+ self.assertGreater(acceptance_probs_.mean(), 0.5)
+ self.assertGreater(bad_acceptance_probs_.mean(), 0.5)
+
# Confirm step size is large enough that we sometimes reject.
- self.assertLess(acceptance_probs_val.mean(), 0.99)
- self.assertLess(bad_acceptance_probs_val.mean(), 0.99)
- _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(),
- updated_draws_val.flatten())
- _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(),
- fake_draws_val.flatten())
- logging.vlog(1, 'acceptance rate for true target: {}'.format(
- acceptance_probs_val.mean()))
- logging.vlog(1, 'acceptance rate for fake target: {}'.format(
- bad_acceptance_probs_val.mean()))
- logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true))
- logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake))
+ self.assertLess(acceptance_probs_.mean(), 0.99)
+ self.assertLess(bad_acceptance_probs_.mean(), 0.99)
+
+ _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
+ updated_draws_.flatten())
+ _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(),
+ fake_draws_.flatten())
+
+ logging_ops.vlog(1, "acceptance rate for true target: {}".format(
+ acceptance_probs_.mean()))
+ logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
+ bad_acceptance_probs_.mean()))
+ logging_ops.vlog(1, "K-S p-value for true target: {}".format(
+ ks_p_value_true))
+ logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
+ ks_p_value_fake))
# Make sure that the MCMC update hasn't changed the empirical CDF much.
self.assertGreater(ks_p_value_true, 1e-3)
# Confirm that targeting the wrong distribution does
# significantly change the empirical CDF.
self.assertLess(ks_p_value_fake, 1e-6)
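
The invariance check above relies on the two-sample Kolmogorov-Smirnov test. A minimal scipy sketch of that logic, using Gaussian stand-ins for the log-gamma target (illustrative, not part of the patch): identical distributions yield a comfortably large p-value, while "cooling" by the exponent 1.1 shrinks the variance enough for the test to notice.

```python
import numpy as np
from scipy import stats

rng = np.random.RandomState(0)
n = 200000  # comparable to the 50000 * 2 * 2 draws flattened above
a = rng.standard_normal(n)
b = rng.standard_normal(n)                 # same distribution as `a`
c = rng.standard_normal(n) / np.sqrt(1.1)  # N(0, 1)**1.1 is N(0, 1/1.1)
_, p_same = stats.ks_2samp(a, b)
_, p_diff = stats.ks_2samp(a, c)
print(p_same)  # typically well above 1e-3
print(p_diff)  # vanishingly small, well below 1e-6
```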
- def _kernel_leaves_target_invariant_wrapper(self, event_dims):
+ def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims):
"""Tests that the kernel leaves the target distribution invariant.
Draws some independent samples from the target distribution,
@@ -267,86 +343,160 @@ class HMCTest(test.TestCase):
does change the target distribution. (And that we can detect that.)
Args:
- event_dims: A tuple of dimensions that should not be treated as
- independent. This allows for multiple chains to be run independently
- in parallel. Default is (), i.e., all dimensions are independent.
+ independent_chain_ndims: Python `int` scalar representing the number of
+ dims associated with independent chains.
"""
- with self.test_session() as sess:
+ with self.test_session(graph=ops.Graph()) as sess:
initial_draws = np.log(np.random.gamma(self._shape_param,
size=[50000, 2, 2]))
initial_draws -= np.log(self._rate_param)
- x_ph = array_ops.placeholder(np.float32, name='x_ph')
+ x_ph = array_ops.placeholder(np.float32, name="x_ph")
feed_dict = {x_ph: initial_draws}
- self._kernel_leaves_target_invariant(x_ph, event_dims, sess,
- feed_dict)
-
- def testKernelLeavesTargetInvariantNullShape(self):
- self._kernel_leaves_target_invariant_wrapper([])
+ self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims,
+ sess, feed_dict)
def testKernelLeavesTargetInvariant1(self):
- self._kernel_leaves_target_invariant_wrapper([1])
+ self._kernel_leaves_target_invariant_wrapper(1)
def testKernelLeavesTargetInvariant2(self):
- self._kernel_leaves_target_invariant_wrapper([2])
+ self._kernel_leaves_target_invariant_wrapper(2)
- def testKernelLeavesTargetInvariant12(self):
- self._kernel_leaves_target_invariant_wrapper([1, 2])
+ def testKernelLeavesTargetInvariant3(self):
+ self._kernel_leaves_target_invariant_wrapper(3)
+
+ def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
+ sess, feed_dict=None):
+ counter = collections.Counter()
- def _ais_gets_correct_log_normalizer(self, init, event_dims, sess,
- feed_dict=None):
def proposal_log_prob(x):
- return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi),
- event_dims)
+ counter["proposal_calls"] += 1
+ event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
+ return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
+ axis=event_dims)
def target_log_prob(x):
+ counter["target_calls"] += 1
+ event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
return self._log_gamma_log_prob(x, event_dims)
if feed_dict is None:
feed_dict = {}
- w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob,
- proposal_log_prob, event_dims)
+ num_steps = 200
+
+ _, ais_weights, _ = hmc.sample_annealed_importance_chain(
+ proposal_log_prob_fn=proposal_log_prob,
+ num_steps=num_steps,
+ target_log_prob_fn=target_log_prob,
+ step_size=0.5,
+ current_state=init,
+ num_leapfrog_steps=2,
+ seed=45)
+
+ # We have three calls because the calculation of `ais_weights` entails
+ # another call to the `convex_combined_log_prob_fn`. We could refactor
+      # things to avoid this, if needed (e.g., b/72994218).
+ self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter)
+
+ event_shape = array_ops.shape(init)[independent_chain_ndims:]
+ event_size = math_ops.reduce_prod(event_shape)
+
+ log_true_normalizer = (
+ -self._shape_param * math_ops.log(self._rate_param)
+ + math_ops.lgamma(self._shape_param))
+ log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
+
+ log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
+ - np.log(num_steps))
+
+ ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
+ ais_weights_size = array_ops.size(ais_weights)
+ standard_error = math_ops.sqrt(
+ _reduce_variance(ratio_estimate_true)
+ / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
+
+ [
+ ratio_estimate_true_,
+ log_true_normalizer_,
+ log_estimated_normalizer_,
+ standard_error_,
+ ais_weights_size_,
+ event_size_,
+ ] = sess.run([
+ ratio_estimate_true,
+ log_true_normalizer,
+ log_estimated_normalizer,
+ standard_error,
+ ais_weights_size,
+ event_size,
+ ], feed_dict)
+
+ logging_ops.vlog(1, " log_true_normalizer: {}\n"
+ " log_estimated_normalizer: {}\n"
+ " ais_weights_size: {}\n"
+ " event_size: {}\n".format(
+ log_true_normalizer_,
+ log_estimated_normalizer_,
+ ais_weights_size_,
+ event_size_))
+ self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
+
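
The normalizer check above is easiest to see in the single-temperature limit, where AIS reduces to plain importance sampling: for proposal `q` and unnormalized target `p_u`, `E_q[p_u(X)/q(X)] = Z`, so `log Z ~= logsumexp(log_w) - log(n)`. A self-contained sketch with the same standard-normal proposal and log-gamma target (illustrative, not part of the patch):

```python
import numpy as np
from scipy import special, stats

rng = np.random.RandomState(1)
shape_param, rate_param = 5., 10.
x = rng.standard_normal(100000)                     # proposal draws: N(0, 1)
log_q = stats.norm.logpdf(x)
log_p_u = shape_param * x - rate_param * np.exp(x)  # unnormalized log-gamma
log_w = log_p_u - log_q
log_z_hat = special.logsumexp(log_w) - np.log(x.size)
# Same closed form as `log_true_normalizer` above.
log_z_true = special.gammaln(shape_param) - shape_param * np.log(rate_param)
print(log_z_hat, log_z_true)  # agree up to Monte Carlo error
```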
+ def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
+ """Tests that AIS yields reasonable estimates of normalizers."""
+ with self.test_session(graph=ops.Graph()) as sess:
+ x_ph = array_ops.placeholder(np.float32, name="x_ph")
+ initial_draws = np.random.normal(size=[30, 2, 1])
+ self._ais_gets_correct_log_normalizer(
+ x_ph,
+ independent_chain_ndims,
+ sess,
+ feed_dict={x_ph: initial_draws})
- w_val = sess.run(w, feed_dict)
- init_shape = sess.run(init, feed_dict).shape
- normalizer_multiplier = np.prod([init_shape[i] for i in event_dims])
+ def testAIS1(self):
+ self._ais_gets_correct_log_normalizer_wrapper(1)
- true_normalizer = -self._shape_param * np.log(self._rate_param)
- true_normalizer += special.gammaln(self._shape_param)
- true_normalizer *= normalizer_multiplier
+ def testAIS2(self):
+ self._ais_gets_correct_log_normalizer_wrapper(2)
- n_weights = np.prod(w_val.shape)
- normalized_w = np.exp(w_val - true_normalizer)
- standard_error = np.std(normalized_w) / np.sqrt(n_weights)
- logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format(
- true_normalizer, np.log(normalized_w.mean()) + true_normalizer,
- n_weights))
- self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error)
+ def testAIS3(self):
+ self._ais_gets_correct_log_normalizer_wrapper(3)
- def _ais_gets_correct_log_normalizer_wrapper(self, event_dims):
- """Tests that AIS yields reasonable estimates of normalizers."""
- with self.test_session() as sess:
- x_ph = array_ops.placeholder(np.float32, name='x_ph')
+ def testSampleAIChainSeedReproducibleWorksCorrectly(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ independent_chain_ndims = 1
+ x = np.random.rand(4, 3, 2)
- initial_draws = np.random.normal(size=[30, 2, 1])
- feed_dict = {x_ph: initial_draws}
+ def proposal_log_prob(x):
+ event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
+ return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
+ axis=event_dims)
- self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess,
- feed_dict)
+ def target_log_prob(x):
+ event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
+ return self._log_gamma_log_prob(x, event_dims)
- def testAISNullShape(self):
- self._ais_gets_correct_log_normalizer_wrapper([])
+ ais_kwargs = dict(
+ proposal_log_prob_fn=proposal_log_prob,
+ num_steps=200,
+ target_log_prob_fn=target_log_prob,
+ step_size=0.5,
+ current_state=x,
+ num_leapfrog_steps=2,
+ seed=53)
- def testAIS1(self):
- self._ais_gets_correct_log_normalizer_wrapper([1])
+ _, ais_weights0, _ = hmc.sample_annealed_importance_chain(
+ **ais_kwargs)
- def testAIS2(self):
- self._ais_gets_correct_log_normalizer_wrapper([2])
+ _, ais_weights1, _ = hmc.sample_annealed_importance_chain(
+ **ais_kwargs)
- def testAIS12(self):
- self._ais_gets_correct_log_normalizer_wrapper([1, 2])
+ [ais_weights0_, ais_weights1_] = sess.run([
+ ais_weights0, ais_weights1])
+
+ self.assertAllClose(ais_weights0_, ais_weights1_,
+ atol=1e-5, rtol=1e-5)
def testNanRejection(self):
"""Tests that an update that yields NaN potentials gets rejected.
@@ -359,86 +509,263 @@ class HMCTest(test.TestCase):
"""
def _unbounded_exponential_log_prob(x):
"""An exponential distribution with log-likelihood NaN for x < 0."""
- per_element_potentials = array_ops.where(x < 0,
- np.nan * array_ops.ones_like(x),
- -x)
+ per_element_potentials = array_ops.where(
+ x < 0.,
+ array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)),
+ -x)
return math_ops.reduce_sum(per_element_potentials)
- with self.test_session() as sess:
+ with self.test_session(graph=ops.Graph()) as sess:
initial_x = math_ops.linspace(0.01, 5, 10)
- updated_x, acceptance_probs, _, _ = hmc.kernel(
- 2., 5, initial_x, _unbounded_exponential_log_prob, [0])
- initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
- [initial_x, updated_x, acceptance_probs])
-
- logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
- logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
- logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
-
- self.assertAllEqual(initial_x_val, updated_x_val)
- self.assertEqual(acceptance_probs_val, 0.)
+ updated_x, kernel_results = hmc.kernel(
+ target_log_prob_fn=_unbounded_exponential_log_prob,
+ current_state=initial_x,
+ step_size=2.,
+ num_leapfrog_steps=5,
+ seed=46)
+ initial_x_, updated_x_, acceptance_probs_ = sess.run(
+ [initial_x, updated_x, kernel_results.acceptance_probs])
+
+ logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
+ logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
+ logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+
+ self.assertAllEqual(initial_x_, updated_x_)
+ self.assertEqual(acceptance_probs_, 0.)
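
The zero acceptance probability asserted above follows from IEEE NaN semantics: once the potential is NaN the energy change is NaN (mapped to +inf by the energy-change computation tested further below), and a Metropolis comparison against it can never succeed. A two-line illustration:

```python
import numpy as np

energy_change = np.nan  # what a NaN potential propagates to
print(bool(np.log(0.5) < -energy_change))  # False: comparisons with NaN fail
print(np.exp(-np.inf))                     # 0.0 after the NaN -> +inf mapping
```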
def testNanFromGradsDontPropagate(self):
"""Test that update with NaN gradients does not cause NaN in results."""
def _nan_log_prob_with_nan_gradient(x):
return np.nan * math_ops.reduce_sum(x)
- with self.test_session() as sess:
+ with self.test_session(graph=ops.Graph()) as sess:
initial_x = math_ops.linspace(0.01, 5, 10)
- updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel(
- 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0])
- initial_x_val, updated_x_val, acceptance_probs_val = sess.run(
- [initial_x, updated_x, acceptance_probs])
-
- logging.vlog(1, 'initial_x = {}'.format(initial_x_val))
- logging.vlog(1, 'updated_x = {}'.format(updated_x_val))
- logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val))
-
- self.assertAllEqual(initial_x_val, updated_x_val)
- self.assertEqual(acceptance_probs_val, 0.)
+ updated_x, kernel_results = hmc.kernel(
+ target_log_prob_fn=_nan_log_prob_with_nan_gradient,
+ current_state=initial_x,
+ step_size=2.,
+ num_leapfrog_steps=5,
+ seed=47)
+ initial_x_, updated_x_, acceptance_probs_ = sess.run(
+ [initial_x, updated_x, kernel_results.acceptance_probs])
+
+ logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
+ logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
+ logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+
+ self.assertAllEqual(initial_x_, updated_x_)
+ self.assertEqual(acceptance_probs_, 0.)
self.assertAllFinite(
- gradients_impl.gradients(updated_x, initial_x)[0].eval())
- self.assertTrue(
- gradients_impl.gradients(new_grad, initial_x)[0] is None)
+ gradients_ops.gradients(updated_x, initial_x)[0].eval())
+ self.assertAllEqual([True], [g is None for g in gradients_ops.gradients(
+ kernel_results.proposed_grads_target_log_prob, initial_x)])
+ self.assertAllEqual([False], [g is None for g in gradients_ops.gradients(
+ kernel_results.proposed_grads_target_log_prob,
+ kernel_results.proposed_state)])
# Gradients of the acceptance probs and new log prob are not finite.
- _ = new_log_prob # Prevent unused arg error.
# self.assertAllFinite(
- # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval())
+ # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval())
# self.assertAllFinite(
- # gradients_impl.gradients(new_log_prob, initial_x)[0].eval())
+ # gradients_ops.gradients(new_log_prob, initial_x)[0].eval())
+
+ def _testChainWorksDtype(self, dtype):
+ with self.test_session(graph=ops.Graph()) as sess:
+ states, kernel_results = hmc.sample_chain(
+ num_results=10,
+ target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
+ current_state=np.zeros(5).astype(dtype),
+ step_size=0.01,
+ num_leapfrog_steps=10,
+ seed=48)
+ states_, acceptance_probs_ = sess.run(
+ [states, kernel_results.acceptance_probs])
+ self.assertEqual(dtype, states_.dtype)
+ self.assertEqual(dtype, acceptance_probs_.dtype)
def testChainWorksIn64Bit(self):
- def log_prob(x):
- return - math_ops.reduce_sum(x * x, axis=-1)
- states, acceptance_probs = hmc.chain(
- n_iterations=10,
- step_size=np.float64(0.01),
- n_leapfrog_steps=10,
- initial_x=np.zeros(5).astype(np.float64),
- target_log_prob_fn=log_prob,
- event_dims=[-1])
- with self.test_session() as sess:
- states_, acceptance_probs_ = sess.run([states, acceptance_probs])
- self.assertEqual(np.float64, states_.dtype)
- self.assertEqual(np.float64, acceptance_probs_.dtype)
+ self._testChainWorksDtype(np.float64)
def testChainWorksIn16Bit(self):
- def log_prob(x):
- return - math_ops.reduce_sum(x * x, axis=-1)
- states, acceptance_probs = hmc.chain(
- n_iterations=10,
- step_size=np.float16(0.01),
- n_leapfrog_steps=10,
- initial_x=np.zeros(5).astype(np.float16),
- target_log_prob_fn=log_prob,
- event_dims=[-1])
- with self.test_session() as sess:
- states_, acceptance_probs_ = sess.run([states, acceptance_probs])
- self.assertEqual(np.float16, states_.dtype)
- self.assertEqual(np.float16, acceptance_probs_.dtype)
-
-
-if __name__ == '__main__':
+ self._testChainWorksDtype(np.float16)
+
+ def testChainWorksCorrelatedMultivariate(self):
+ dtype = np.float32
+ true_mean = dtype([0, 0])
+ true_cov = dtype([[1, 0.5],
+ [0.5, 1]])
+ num_results = 2000
+ counter = collections.Counter()
+ with self.test_session(graph=ops.Graph()) as sess:
+ def target_log_prob(x, y):
+ counter["target_calls"] += 1
+ # Corresponds to unnormalized MVN.
+ # z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
+ z = array_ops.stack([x, y], axis=-1) - true_mean
+ z = array_ops.squeeze(
+ gen_linalg_ops.matrix_triangular_solve(
+ np.linalg.cholesky(true_cov),
+ z[..., array_ops.newaxis]),
+ axis=-1)
+ return -0.5 * math_ops.reduce_sum(z**2., axis=-1)
+ states, _ = hmc.sample_chain(
+ num_results=num_results,
+ target_log_prob_fn=target_log_prob,
+ current_state=[dtype(-2), dtype(2)],
+ step_size=[0.5, 0.5],
+ num_leapfrog_steps=2,
+ num_burnin_steps=200,
+ num_steps_between_results=1,
+ seed=54)
+ self.assertAllEqual(dict(target_calls=2), counter)
+ states = array_ops.stack(states, axis=-1)
+ self.assertEqual(num_results, states.shape[0].value)
+ sample_mean = math_ops.reduce_mean(states, axis=0)
+ x = states - sample_mean
+ sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results)
+ [sample_mean_, sample_cov_] = sess.run([
+ sample_mean, sample_cov])
+ self.assertAllClose(true_mean, sample_mean_,
+ atol=0.05, rtol=0.)
+ self.assertAllClose(true_cov, sample_cov_,
+ atol=0., rtol=0.1)
+
+
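
The `target_log_prob` in the multivariate test above whitens via a triangular solve. As a quick NumPy check (illustrative, not part of the patch), solving `L z = x - mu` with `L = chol(cov)` makes `-0.5 * ||z||**2` the unnormalized multivariate-normal log-density, since `(x - mu)^T cov^{-1} (x - mu) = ||z||**2`:

```python
import numpy as np
from scipy import stats

true_mean = np.zeros(2)
true_cov = np.array([[1., 0.5], [0.5, 1.]])
x = np.array([-2., 2.])
L = np.linalg.cholesky(true_cov)
z = np.linalg.solve(L, x - true_mean)
lhs = -0.5 * np.sum(z**2)
rhs = stats.multivariate_normal(true_mean, true_cov).logpdf(x)
# Normalizing constant for d=2: -(d/2)*log(2*pi) - 0.5*log(det(cov)).
const = -np.log(2 * np.pi) - 0.5 * np.log(np.linalg.det(true_cov))
assert np.isclose(lhs, rhs - const)
```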
+class _EnergyComputationTest(object):
+
+ def testHandlesNanFromPotential(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ x = [1, np.inf, -np.inf, np.nan]
+ target_log_prob, proposed_target_log_prob = [
+ self.dtype(x.flatten()) for x in np.meshgrid(x, x)]
+ num_chains = len(target_log_prob)
+ dummy_momentums = [-1, 1]
+ momentums = [self.dtype([dummy_momentums] * num_chains)]
+ proposed_momentums = [self.dtype([dummy_momentums] * num_chains)]
+
+ target_log_prob = ops.convert_to_tensor(target_log_prob)
+ momentums = [ops.convert_to_tensor(momentums[0])]
+ proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
+ proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
+
+ energy = _compute_energy_change(
+ target_log_prob,
+ momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims=1)
+ grads = gradients_ops.gradients(energy, momentums)
+
+ [actual_energy, grads_] = sess.run([energy, grads])
+
+ # Ensure energy is `inf` (note: that's positive inf) in weird cases and
+ # finite otherwise.
+ expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
+ self.assertAllEqual(expected_energy, actual_energy)
+
+ # Ensure gradient is finite.
+ self.assertAllEqual(np.ones_like(grads_).astype(np.bool),
+ np.isfinite(grads_))
+
+ def testHandlesNanFromKinetic(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ x = [1, np.inf, -np.inf, np.nan]
+ momentums, proposed_momentums = [
+ [np.reshape(self.dtype(x), [-1, 1])]
+ for x in np.meshgrid(x, x)]
+ num_chains = len(momentums[0])
+ target_log_prob = np.ones(num_chains, self.dtype)
+ proposed_target_log_prob = np.ones(num_chains, self.dtype)
+
+ target_log_prob = ops.convert_to_tensor(target_log_prob)
+ momentums = [ops.convert_to_tensor(momentums[0])]
+ proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
+ proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
+
+ energy = _compute_energy_change(
+ target_log_prob,
+ momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims=1)
+ grads = gradients_ops.gradients(energy, momentums)
+
+ [actual_energy, grads_] = sess.run([energy, grads])
+
+ # Ensure energy is `inf` (note: that's positive inf) in weird cases and
+ # finite otherwise.
+ expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
+ self.assertAllEqual(expected_energy, actual_energy)
+
+ # Ensure gradient is finite.
+ g = grads_[0].reshape([len(x), len(x)])[:, 0]
+ self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g))
+
+ # The remaining gradients are nan because the momentum was itself nan or
+ # inf.
+ g = grads_[0].reshape([len(x), len(x)])[:, 1:]
+ self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g))
+
+
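
A simplified NumPy analogue of the quantity these tests exercise (illustrative; the private helper is more careful about infs): the energy change is `dE = [-log p(x') + K(m')] - [-log p(x) + K(m)]` with kinetic term `K(m) = 0.5 * sum(m**2)`, and any NaN maps to +inf so that `exp(-dE) = 0` forces rejection.

```python
import numpy as np

def energy_change(tlp, momentums, proposed_tlp, proposed_momentums):
  kinetic = lambda parts: sum(0.5 * np.sum(m**2, axis=-1) for m in parts)
  de = ((-proposed_tlp + kinetic(proposed_momentums))
        - (-tlp + kinetic(momentums)))
  return np.where(np.isnan(de), np.inf, de)  # NaN can never be accepted

tlp = np.array([1., 1.])
bad_tlp = np.array([1., np.nan])
momentums = [np.array([[-1., 1.], [-1., 1.]])]
print(energy_change(tlp, momentums, bad_tlp, momentums))  # [0., inf]
```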
+class EnergyComputationTest16(test.TestCase, _EnergyComputationTest):
+ dtype = np.float16
+
+
+class EnergyComputationTest32(test.TestCase, _EnergyComputationTest):
+ dtype = np.float32
+
+
+class EnergyComputationTest64(test.TestCase, _EnergyComputationTest):
+ dtype = np.float64
+
+
+class _HMCHandlesLists(object):
+
+ def testStateParts(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
+ dist_y = independent_lib.Independent(
+ gamma_lib.Gamma(concentration=self.dtype([1, 2]),
+ rate=self.dtype([0.5, 0.75])),
+ reinterpreted_batch_ndims=1)
+ def target_log_prob(x, y):
+ return dist_x.log_prob(x) + dist_y.log_prob(y)
+ x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
+ samples, _ = hmc.sample_chain(
+ num_results=int(2e3),
+ target_log_prob_fn=target_log_prob,
+ current_state=x0,
+ step_size=0.85,
+ num_leapfrog_steps=3,
+ num_burnin_steps=int(250),
+ seed=49)
+ actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
+ actual_vars = [_reduce_variance(s, axis=0) for s in samples]
+ expected_means = [dist_x.mean(), dist_y.mean()]
+ expected_vars = [dist_x.variance(), dist_y.variance()]
+ [
+ actual_means_,
+ actual_vars_,
+ expected_means_,
+ expected_vars_,
+ ] = sess.run([
+ actual_means,
+ actual_vars,
+ expected_means,
+ expected_vars,
+ ])
+ self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
+ self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
+
+
+class HMCHandlesLists32(_HMCHandlesLists, test.TestCase):
+ dtype = np.float32
+
+
+class HMCHandlesLists64(_HMCHandlesLists, test.TestCase):
+ dtype = np.float64
+
+
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py
index 977d42fc16..7fd5652c5c 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
-"""
+"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm."""
from __future__ import absolute_import
from __future__ import division
@@ -24,11 +23,9 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disabl
from tensorflow.python.util import all_util
_allowed_symbols = [
- 'chain',
- 'kernel',
- 'leapfrog_integrator',
- 'leapfrog_step',
- 'ais_chain'
+ "sample_chain",
+ "sample_annealed_importance_chain",
+ "kernel",
]
all_util.remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
index 5685a942e9..f724910c59 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
@@ -14,17 +14,16 @@
# ==============================================================================
"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
-@@chain
-@@update
-@@leapfrog_integrator
-@@leapfrog_step
-@@ais_chain
+@@sample_chain
+@@sample_annealed_importance_chain
+@@kernel
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import numpy as np
from tensorflow.python.framework import dtypes
@@ -32,168 +31,326 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_impl as gradients_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.ops.distributions import util as distributions_util
__all__ = [
- 'chain',
- 'kernel',
- 'leapfrog_integrator',
- 'leapfrog_step',
- 'ais_chain'
+ "sample_chain",
+ "sample_annealed_importance_chain",
+ "kernel",
]
-def _make_potential_and_grad(target_log_prob_fn):
- def potential_and_grad(x):
- log_prob_result = -target_log_prob_fn(x)
- grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result),
- x)[0]
- return log_prob_result, grad_result
- return potential_and_grad
-
-
-def chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
- target_log_prob_fn, event_dims=(), name=None):
+KernelResults = collections.namedtuple(
+ "KernelResults",
+ [
+ "acceptance_probs",
+ "current_grads_target_log_prob", # "Current result" means "accepted".
+ "current_target_log_prob", # "Current result" means "accepted".
+ "energy_change",
+ "is_accepted",
+ "proposed_grads_target_log_prob",
+ "proposed_state",
+ "proposed_target_log_prob",
+ "random_positive",
+ ])
+
+
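
`KernelResults` is a plain namedtuple, so per-step diagnostics thread through `tf.while_loop`/`tf.scan` as a flat structure and read back by attribute. A tiny illustration with abridged fields (not part of the patch):

```python
import collections

KernelResults = collections.namedtuple(
    "KernelResults", ["acceptance_probs", "is_accepted"])  # abridged

kr = KernelResults(acceptance_probs=0.97, is_accepted=True)
assert kr.is_accepted and kr.acceptance_probs > 0.5
```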
+def _make_dummy_kernel_results(
+ dummy_state,
+ dummy_target_log_prob,
+ dummy_grads_target_log_prob):
+ return KernelResults(
+ acceptance_probs=dummy_target_log_prob,
+ current_grads_target_log_prob=dummy_grads_target_log_prob,
+ current_target_log_prob=dummy_target_log_prob,
+ energy_change=dummy_target_log_prob,
+ is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool),
+ proposed_grads_target_log_prob=dummy_grads_target_log_prob,
+ proposed_state=dummy_state,
+ proposed_target_log_prob=dummy_target_log_prob,
+ random_positive=dummy_target_log_prob,
+ )
+
+
+def sample_chain(
+ num_results,
+ target_log_prob_fn,
+ current_state,
+ step_size,
+ num_leapfrog_steps,
+ num_burnin_steps=0,
+ num_steps_between_results=0,
+ seed=None,
+ current_target_log_prob=None,
+ current_grads_target_log_prob=None,
+ name=None):
"""Runs multiple iterations of one or more Hamiltonian Monte Carlo chains.
- Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
- algorithm that takes a series of gradient-informed steps to produce
- a Metropolis proposal. This function samples from an HMC Markov
- chain whose initial state is `initial_x` and whose stationary
- distribution has log-density `target_log_prob_fn()`.
-
- This function can update multiple chains in parallel. It assumes
- that all dimensions of `initial_x` not specified in `event_dims` are
- independent, and should therefore be updated independently. The
- output of `target_log_prob_fn()` should sum log-probabilities across
- all event dimensions. Slices along dimensions not in `event_dims`
- may have different target distributions; this is up to
+ Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
+ that takes a series of gradient-informed steps to produce a Metropolis
+ proposal. This function samples from an HMC Markov chain whose initial state
+ is `current_state` and whose stationary distribution has unnormalized log-density
`target_log_prob_fn()`.
- This function basically just wraps `hmc.kernel()` in a tf.scan() loop.
+ This function samples from multiple chains in parallel. It assumes that the
+ leftmost dimensions of (each) `current_state` (part) index an independent
+ chain. The function `target_log_prob_fn()` sums log-probabilities across
+ event dimensions (i.e., current state (part) rightmost dimensions). Each
+ element of the output of `target_log_prob_fn()` represents the (possibly
+ unnormalized) log-probability of the joint distribution over (all) the current
+ state (parts).
- Args:
- n_iterations: Integer number of Markov chain updates to run.
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `initial_x`. Larger step sizes lead to faster progress, but
- too-large step sizes make rejection exponentially more likely.
- When possible, it's often helpful to match per-variable step
- sizes to the standard deviations of the target distribution in
- each variable.
- n_leapfrog_steps: Integer number of steps to run the leapfrog
- integrator for. Total progress per HMC step is roughly
- proportional to step_size * n_leapfrog_steps.
- initial_x: Tensor of initial state(s) of the Markov chain(s).
- target_log_prob_fn: Python callable which takes an argument like `initial_x`
- and returns its (possibly unnormalized) log-density under the target
- distribution.
- event_dims: List of dimensions that should not be treated as
- independent. This allows for multiple chains to be run independently
- in parallel. Default is (), i.e., all dimensions are independent.
- name: Python `str` name prefixed to Ops created by this function.
+ The `current_state` can be represented as a single `Tensor` or a `list` of
+ `Tensors` which collectively represent the current state. When specifying a
+ `list`, one must also specify a list of `step_size`s.
- Returns:
- acceptance_probs: Tensor with the acceptance probabilities for each
- iteration. Has shape matching `target_log_prob_fn(initial_x)`.
- chain_states: Tensor with the state of the Markov chain at each iteration.
- Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`.
+ Note: `target_log_prob_fn` is called exactly twice.
+
+ Only one out of every `num_steps_between_results + 1` steps is included in the
+ returned results. This "thinning" comes at a cost of reduced statistical
+ power, while reducing memory requirements and autocorrelation. For more
+ discussion see [1].
+
+ [1]: "Statistically efficient thinning of a Markov chain sampler."
+ Art B. Owen. April 2017.
+ http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
#### Examples:
- ```python
- # Sampling from a standard normal (note `log_joint()` is unnormalized):
- def log_joint(x):
- return tf.reduce_sum(-0.5 * tf.square(x))
- chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
- event_dims=[0])
- # Discard first half of chain as warmup/burn-in
- warmed_up = chain[500:]
- mean_est = tf.reduce_mean(warmed_up, 0)
- var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
- ```
+ ##### Sample from a diagonal-variance Gaussian.
```python
- # Sampling from a diagonal-variance Gaussian:
- variances = tf.linspace(1., 3., 10)
- def log_joint(x):
- return tf.reduce_sum(-0.5 / variances * tf.square(x))
- chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint,
- event_dims=[0])
- # Discard first half of chain as warmup/burn-in
- warmed_up = chain[500:]
- mean_est = tf.reduce_mean(warmed_up, 0)
- var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+ tfd = tf.contrib.distributions
+
+ def make_likelihood(true_variances):
+ return tfd.MultivariateNormalDiag(
+ scale_diag=tf.sqrt(true_variances))
+
+ dims = 10
+ dtype = np.float32
+ true_variances = tf.linspace(dtype(1), dtype(3), dims)
+ likelihood = make_likelihood(true_variances)
+
+ states, kernel_results = hmc.sample_chain(
+ num_results=1000,
+ target_log_prob_fn=likelihood.log_prob,
+ current_state=tf.zeros(dims),
+ step_size=0.5,
+ num_leapfrog_steps=2,
+ num_burnin_steps=500)
+
+ # Compute sample stats.
+ sample_mean = tf.reduce_mean(states, axis=0)
+ sample_var = tf.reduce_mean(
+ tf.squared_difference(states, sample_mean),
+ axis=0)
```
- ```python
- # Sampling from factor-analysis posteriors with known factors W:
- # mu[i, j] ~ Normal(0, 1)
- # x[i] ~ Normal(matmul(mu[i], W), I)
- def log_joint(mu, x, W):
- prior = -0.5 * tf.reduce_sum(tf.square(mu), 1)
- x_mean = tf.matmul(mu, W)
- likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1)
- return prior + likelihood
- chain, acceptance_probs = hmc.chain(1000, 0.1, 2,
- tf.zeros([x.shape[0], W.shape[0]]),
- lambda mu: log_joint(mu, x, W),
- event_dims=[1])
- # Discard first half of chain as warmup/burn-in
- warmed_up = chain[500:]
- mean_est = tf.reduce_mean(warmed_up, 0)
- var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est)
+ ##### Sampling from factor-analysis posteriors with known factors.
+
+ I.e.,
+
+ ```none
+ for i=1..n:
+ w[i] ~ Normal(0, eye(d)) # prior
+ x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood
```
+ where `F` denotes factors.
+
```python
- # Sampling from the posterior of a Bayesian regression model.:
-
- # Run 100 chains in parallel, each with a different initialization.
- initial_beta = tf.random_normal([100, x.shape[1]])
- chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta,
- log_joint_partial, event_dims=[1])
- # Discard first halves of chains as warmup/burn-in
- warmed_up = chain[500:]
- # Averaging across samples within a chain and across chains
- mean_est = tf.reduce_mean(warmed_up, [0, 1])
- var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est)
+ tfd = tf.contrib.distributions
+
+ def make_prior(dims, dtype):
+ return tfd.MultivariateNormalDiag(
+ loc=tf.zeros(dims, dtype))
+
+ def make_likelihood(weights, factors):
+ return tfd.MultivariateNormalDiag(
+ loc=tf.tensordot(weights, factors, axes=[[-1], [-1]]))
+
+ # Setup data.
+ num_weights = 10
+ num_factors = 4
+ num_chains = 100
+ dtype = np.float32
+
+ prior = make_prior(num_weights, dtype)
+ weights = prior.sample(num_chains)
+ factors = np.random.randn(num_factors, num_weights).astype(dtype)
+ x = make_likelihood(weights, factors).sample()
+
+ def target_log_prob(w):
+ # Target joint is: `f(w) = p(w, x | factors)`.
+ return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)
+
+ # Get `num_results` samples from `num_chains` independent chains.
+ chains_states, kernels_results = hmc.sample_chain(
+ num_results=1000,
+ target_log_prob_fn=target_log_prob,
+ current_state=tf.zeros([num_chains, num_weights], dtype),
+ step_size=0.1,
+ num_leapfrog_steps=2,
+ num_burnin_steps=500)
+
+ # Compute sample stats.
+ sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
+ sample_var = tf.reduce_mean(
+ tf.squared_difference(chains_states, sample_mean),
+ axis=[0, 1])
```
+
+ Args:
+ num_results: Integer number of Markov chain draws.
+ target_log_prob_fn: Python callable which takes an argument like
+ `current_state` (or `*current_state` if it's a list) and returns its
+ (possibly unnormalized) log-density under the target distribution.
+ current_state: `Tensor` or Python `list` of `Tensor`s representing the
+ current state(s) of the Markov chain(s). The first `r` dimensions index
+ independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+ step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+ for the leapfrog integrator. Must broadcast with the shape of
+ `current_state`. Larger step sizes lead to faster progress, but too-large
+ step sizes make rejection exponentially more likely. When possible, it's
+ often helpful to match per-variable step sizes to the standard deviations
+ of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ num_burnin_steps: Integer number of chain steps to take before starting to
+ collect results.
+ Default value: 0 (i.e., no burn-in).
+ num_steps_between_results: Integer number of chain steps between collecting
+ a result. Only one out of every `num_steps_between_results + 1` steps is
+ included in the returned results. This "thinning" comes at a cost of
+ reduced statistical power, while reducing memory requirements and
+ autocorrelation. For more discussion see [1].
+ Default value: 0 (i.e., no subsampling).
+ seed: Python integer to seed the random number generator.
+ current_target_log_prob: (Optional) `Tensor` representing the value of
+ `target_log_prob_fn` at the `current_state`. The only reason to specify
+ this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+ representing gradient of `target_log_prob` at the `current_state` and wrt
+ the `current_state`. Must have same shape as `current_state`. The only
+ reason to specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_sample_chain").
+
+ Returns:
+ accepted_states: Tensor or Python list of `Tensor`s representing the
+ state(s) of the Markov chain(s) at each result step. Has same shape as
+ input `current_state` but with a prepended `num_results`-size dimension.
+ kernel_results: `collections.namedtuple` of internal calculations used to
+ advance the chain.
"""
- with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size,
- n_leapfrog_steps, initial_x]):
- initial_x = ops.convert_to_tensor(initial_x, name='initial_x')
- non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
-
- def body(a, _):
- updated_x, acceptance_probs, log_prob, grad = kernel(
- step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims,
- a[2], a[3])
- return updated_x, acceptance_probs, log_prob, grad
-
- potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
- potential, grad = potential_and_grad(initial_x)
- return functional_ops.scan(
- body, array_ops.zeros(n_iterations, dtype=initial_x.dtype),
- (initial_x,
- array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
- -potential, -grad))[:2]
-
-
-def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
- target_log_prob_fn, proposal_log_prob_fn, event_dims=(),
- name=None):
+ with ops.name_scope(
+ name, "hmc_sample_chain",
+ [num_results, current_state, step_size, num_leapfrog_steps,
+ num_burnin_steps, num_steps_between_results, seed,
+ current_target_log_prob, current_grads_target_log_prob]):
+ with ops.name_scope("initialize"):
+ [
+ current_state,
+ step_size,
+ current_target_log_prob,
+ current_grads_target_log_prob,
+ ] = _prepare_args(
+ target_log_prob_fn,
+ current_state,
+ step_size,
+ current_target_log_prob,
+ current_grads_target_log_prob)
+ num_results = ops.convert_to_tensor(
+ num_results,
+ dtype=dtypes.int32,
+ name="num_results")
+ num_leapfrog_steps = ops.convert_to_tensor(
+ num_leapfrog_steps,
+ dtype=dtypes.int32,
+ name="num_leapfrog_steps")
+ num_burnin_steps = ops.convert_to_tensor(
+ num_burnin_steps,
+ dtype=dtypes.int32,
+ name="num_burnin_steps")
+ num_steps_between_results = ops.convert_to_tensor(
+ num_steps_between_results,
+ dtype=dtypes.int32,
+ name="num_steps_between_results")
+
+ def _run_chain(num_steps, current_state, kernel_results):
+ """Runs the chain(s) for `num_steps`."""
+ def _loop_body(iter_, current_state, kernel_results):
+ return [iter_ + 1] + list(kernel(
+ target_log_prob_fn,
+ current_state,
+ step_size,
+ num_leapfrog_steps,
+ seed,
+ kernel_results.current_target_log_prob,
+ kernel_results.current_grads_target_log_prob))
+ while_loop_kwargs = dict(
+ cond=lambda iter_, *args: iter_ < num_steps,
+ body=_loop_body,
+ loop_vars=[
+ np.int32(0),
+ current_state,
+ kernel_results,
+ ],
+ )
+ if seed is not None:
+ while_loop_kwargs["parallel_iterations"] = 1
+ return control_flow_ops.while_loop(
+ **while_loop_kwargs)[1:] # Lop-off "iter_".
+
+ def _scan_body(args_list, iter_):
+ """Closure which implements `tf.scan` body."""
+ current_state, kernel_results = args_list
+ return _run_chain(
+ 1 + array_ops.where(math_ops.equal(iter_, 0),
+ num_burnin_steps,
+ num_steps_between_results),
+ current_state,
+ kernel_results)
+
+ scan_kwargs = dict(
+ fn=_scan_body,
+ elems=math_ops.range(num_results), # iter_: used to choose burnin.
+ initializer=[
+ current_state,
+ _make_dummy_kernel_results(
+ current_state,
+ current_target_log_prob,
+ current_grads_target_log_prob),
+ ])
+ if seed is not None:
+ scan_kwargs["parallel_iterations"] = 1
+ return functional_ops.scan(**scan_kwargs)
+
+
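
A small arithmetic consequence of `_scan_body` above (illustrative, not part of the patch): the first result costs `1 + num_burnin_steps` kernel steps and each later result costs `1 + num_steps_between_results`, so the total number of chain steps is:

```python
num_results = 100
num_burnin_steps = 500
num_steps_between_results = 3
total_steps = (num_burnin_steps + num_results
               + (num_results - 1) * num_steps_between_results)
assert total_steps == 897  # only `num_results` of these states are returned
```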
+def sample_annealed_importance_chain(
+ proposal_log_prob_fn,
+ num_steps,
+ target_log_prob_fn,
+ current_state,
+ step_size,
+ num_leapfrog_steps,
+ seed=None,
+ name=None):
"""Runs annealed importance sampling (AIS) to estimate normalizing constants.
- This routine uses Hamiltonian Monte Carlo to sample from a series of
+ This function uses Hamiltonian Monte Carlo to sample from a series of
distributions that slowly interpolates between an initial "proposal"
- distribution
+ distribution:
`exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`
- and the target distribution
+ and the target distribution:
`exp(target_log_prob_fn(x) - target_log_normalizer)`,
@@ -202,113 +359,203 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x,
normalizing constants of the initial distribution and the target
distribution:
- E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer).
+ `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
- Args:
- n_iterations: Integer number of Markov chain updates to run. More
- iterations means more expense, but smoother annealing between q
- and p, which in turn means exponentially lower variance for the
- normalizing constant estimator.
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `initial_x`. Larger step sizes lead to faster progress, but
- too-large step sizes make rejection exponentially more likely.
- When possible, it's often helpful to match per-variable step
- sizes to the standard deviations of the target distribution in
- each variable.
- n_leapfrog_steps: Integer number of steps to run the leapfrog
- integrator for. Total progress per HMC step is roughly
- proportional to step_size * n_leapfrog_steps.
- initial_x: Tensor of initial state(s) of the Markov chain(s). Must
- be a sample from q, or results will be incorrect.
- target_log_prob_fn: Python callable which takes an argument like `initial_x`
- and returns its (possibly unnormalized) log-density under the target
- distribution.
- proposal_log_prob_fn: Python callable that returns the log density of the
- initial distribution.
- event_dims: List of dimensions that should not be treated as
- independent. This allows for multiple chains to be run independently
- in parallel. Default is (), i.e., all dimensions are independent.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- ais_weights: Tensor with the estimated weight(s). Has shape matching
- `target_log_prob_fn(initial_x)`.
- chain_states: Tensor with the state(s) of the Markov chain(s) the final
- iteration. Has shape matching `initial_x`.
- acceptance_probs: Tensor with the acceptance probabilities for the final
- iteration. Has shape matching `target_log_prob_fn(initial_x)`.
+ Note: `proposal_log_prob_fn` and `target_log_prob_fn` are each called exactly
+ three times (although this may be reduced to two in the future).
#### Examples:
+ ##### Estimate the normalizing constant of a log-gamma distribution.
+
```python
- # Estimating the normalizing constant of a log-gamma distribution:
- def proposal_log_prob(x):
- # Standard normal log-probability. This is properly normalized.
- return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1)
- def target_log_prob(x):
- # Unnormalized log-gamma(2, 3) distribution.
- # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1]
- return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1)
+ tfd = tf.contrib.distributions
+
# Run 100 AIS chains in parallel
- initial_x = tf.random_normal([100, 20])
- w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob,
- proposal_log_prob, event_dims=[1])
- log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+ num_chains = 100
+ dims = 20
+ dtype = np.float32
+
+ proposal = tfd.MultivariateNormalDiag(
+ loc=tf.zeros([dims], dtype=dtype))
+
+ target = tfd.TransformedDistribution(
+ distribution=tfd.Gamma(concentration=dtype(2),
+ rate=dtype(3)),
+ bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
+ event_shape=[dims])
+
+ chains_state, ais_weights, kernel_results = (
+ hmc.sample_annealed_importance_chain(
+ proposal_log_prob_fn=proposal.log_prob,
+ num_steps=1000,
+ target_log_prob_fn=target.log_prob,
+ step_size=0.2,
+ current_state=proposal.sample(num_chains),
+ num_leapfrog_steps=2))
+
+ log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
+ - np.log(num_chains))
+ log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
```
+ ##### Estimate marginal likelihood of a Bayesian regression model.
+
```python
- # Estimating the marginal likelihood of a Bayesian regression model:
- base_measure = -0.5 * np.log(2 * np.pi)
- def proposal_log_prob(x):
- # Standard normal log-probability. This is properly normalized.
- return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1)
- def regression_log_joint(beta, x, y):
- # This function returns a vector whose ith element is log p(beta[i], y | x).
- # Each row of beta corresponds to the state of an independent Markov chain.
- log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1)
- means = tf.matmul(beta, x, transpose_b=True)
- log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) +
- base_measure, 1)
- return log_prior + log_likelihood
- def log_joint_partial(beta):
- return regression_log_joint(beta, x, y)
+ tfd = tf.contrib.distributions
+
+ def make_prior(dims, dtype):
+ return tfd.MultivariateNormalDiag(
+ loc=tf.zeros(dims, dtype))
+
+ def make_likelihood(weights, x):
+ return tfd.MultivariateNormalDiag(
+ loc=tf.tensordot(weights, x, axes=[[0], [-1]]))
+
# Run 100 AIS chains in parallel
- initial_beta = tf.random_normal([100, x.shape[1]])
- w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta,
- log_joint_partial, proposal_log_prob,
- event_dims=[1])
- log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100)
+ num_chains = 100
+ dims = 10
+ dtype = np.float32
+
+ # Make training data.
+ x = np.random.randn(num_chains, dims).astype(dtype)
+ true_weights = np.random.randn(dims).astype(dtype)
+ y = np.dot(x, true_weights) + np.random.randn(num_chains).astype(dtype)
+
+ # Setup model.
+ prior = make_prior(dims, dtype)
+ def target_log_prob_fn(weights):
+ return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
+
+ proposal = tfd.MultivariateNormalDiag(
+ loc=tf.zeros(dims, dtype))
+
+ weight_samples, ais_weights, kernel_results = (
+ hmc.sample_annealed_importance_chain(
+ num_steps=1000,
+ proposal_log_prob_fn=proposal.log_prob,
+ target_log_prob_fn=target_log_prob_fn,
+ current_state=tf.zeros([num_chains, dims], dtype),
+ step_size=0.1,
+ num_leapfrog_steps=2))
+ log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
+ - np.log(num_chains))
```
+
+ Args:
+ proposal_log_prob_fn: Python callable that returns the log density of the
+ initial distribution.
+ num_steps: Integer number of Markov chain updates to run. More
+ iterations means more expense, but smoother annealing between q
+ and p, which in turn means exponentially lower variance for the
+ normalizing constant estimator.
+ target_log_prob_fn: Python callable which takes an argument like
+ `current_state` (or `*current_state` if it's a list) and returns its
+ (possibly unnormalized) log-density under the target distribution.
+ current_state: `Tensor` or Python `list` of `Tensor`s representing the
+ current state(s) of the Markov chain(s). The first `r` dimensions index
+ independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+ step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+ for the leapfrog integrator. Must broadcast with the shape of
+ `current_state`. Larger step sizes lead to faster progress, but too-large
+ step sizes make rejection exponentially more likely. When possible, it's
+ often helpful to match per-variable step sizes to the standard deviations
+ of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ seed: Python integer to seed the random number generator.
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
+
+ Returns:
+ accepted_state: `Tensor` or Python list of `Tensor`s representing the
+ state(s) of the Markov chain(s) at the final iteration. Has same shape as
+ input `current_state`.
+ ais_weights: Tensor with the estimated weight(s). Has shape matching
+ `target_log_prob_fn(current_state)`.
+ kernel_results: `collections.namedtuple` of internal calculations used to
+ advance the chain.
"""
- with ops.name_scope(name, 'hmc_ais_chain',
- [n_iterations, step_size, n_leapfrog_steps, initial_x]):
- non_event_shape = array_ops.shape(target_log_prob_fn(initial_x))
-
- beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:]
- def _body(a, beta): # pylint: disable=missing-docstring
- def log_prob_beta(x):
- return ((1 - beta) * proposal_log_prob_fn(x) +
- beta * target_log_prob_fn(x))
- last_x = a[0]
- w = a[2]
- w += (1. / n_iterations) * (target_log_prob_fn(last_x) -
- proposal_log_prob_fn(last_x))
- # TODO(b/66917083): There's an opportunity for gradient reuse here.
- updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps,
- last_x, log_prob_beta,
- event_dims)
- return updated_x, acceptance_probs, w
-
- x, acceptance_probs, w = functional_ops.scan(
- _body, beta_series,
- (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype),
- array_ops.zeros(non_event_shape, dtype=initial_x.dtype)))
- return w[-1], x[-1], acceptance_probs[-1]
-
-
-def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
- x_log_prob=None, x_grad=None, name=None):
+ def make_convex_combined_log_prob_fn(iter_):
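+ """Returns a fn evaluating `(1 - beta) * log_p + beta * log_t`, i.e., the
+ log of the (unnormalized) geometric path `p**(1 - beta) * t**beta`."""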
+ def _fn(*args):
+ p = proposal_log_prob_fn(*args)
+ t = target_log_prob_fn(*args)
+ dtype = p.dtype.base_dtype
+ beta = (math_ops.cast(iter_ + 1, dtype)
+ / math_ops.cast(num_steps, dtype))
+ return (1. - beta) * p + beta * t
+ return _fn
+
+ with ops.name_scope(
+ name, "hmc_sample_annealed_importance_chain",
+ [num_steps, current_state, step_size, num_leapfrog_steps, seed]):
+ with ops.name_scope("initialize"):
+ [
+ current_state,
+ step_size,
+ current_log_prob,
+ current_grads_log_prob,
+ ] = _prepare_args(
+ make_convex_combined_log_prob_fn(iter_=0),
+ current_state,
+ step_size,
+ description="convex_combined_log_prob")
+ num_steps = ops.convert_to_tensor(
+ num_steps,
+ dtype=dtypes.int32,
+ name="num_steps")
+ num_leapfrog_steps = ops.convert_to_tensor(
+ num_leapfrog_steps,
+ dtype=dtypes.int32,
+ name="num_leapfrog_steps")
+ def _loop_body(iter_, ais_weights, current_state, kernel_results):
+ """Closure which implements `tf.while_loop` body."""
+ current_state_parts = (list(current_state)
+ if _is_list_like(current_state)
+ else [current_state])
+ # TODO(b/72994218): Consider refactoring things to avoid this unnecessary
+ # call.
+ ais_weights += ((target_log_prob_fn(*current_state_parts)
+ - proposal_log_prob_fn(*current_state_parts))
+ / math_ops.cast(num_steps, ais_weights.dtype))
+ return [iter_ + 1, ais_weights] + list(kernel(
+ make_convex_combined_log_prob_fn(iter_),
+ current_state,
+ step_size,
+ num_leapfrog_steps,
+ seed,
+ kernel_results.current_target_log_prob,
+ kernel_results.current_grads_target_log_prob))
+
+ while_loop_kwargs = dict(
+ cond=lambda iter_, *args: iter_ < num_steps,
+ body=_loop_body,
+ loop_vars=[
+ np.int32(0), # iter_
+ array_ops.zeros_like(current_log_prob), # ais_weights
+ current_state,
+ _make_dummy_kernel_results(current_state,
+ current_log_prob,
+ current_grads_log_prob),
+ ])
+ if seed is not None:
+ while_loop_kwargs["parallel_iterations"] = 1
+
+ [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop(
+ **while_loop_kwargs)[1:] # Lop-off "iter_".
+
+ return [current_state, ais_weights, kernel_results]
+
+
+def kernel(target_log_prob_fn,
+ current_state,
+ step_size,
+ num_leapfrog_steps,
+ seed=None,
+ current_target_log_prob=None,
+ current_grads_target_log_prob=None,
+ name=None):
"""Runs one iteration of Hamiltonian Monte Carlo.
Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
@@ -316,334 +563,623 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(),
a Metropolis proposal. This function applies one step of HMC to
randomly update the variable `x`.
- This function can update multiple chains in parallel. It assumes
- that all dimensions of `x` not specified in `event_dims` are
- independent, and should therefore be updated independently. The
- output of `target_log_prob_fn()` should sum log-probabilities across
- all event dimensions. Slices along dimensions not in `event_dims`
- may have different target distributions; for example, if
- `event_dims == (1,)`, then `x[0, :]` could have a different target
- distribution from x[1, :]. This is up to `target_log_prob_fn()`.
-
- Args:
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `x`. Larger step sizes lead to faster progress, but
- too-large step sizes make rejection exponentially more likely.
- When possible, it's often helpful to match per-variable step
- sizes to the standard deviations of the target distribution in
- each variable.
- n_leapfrog_steps: Integer number of steps to run the leapfrog
- integrator for. Total progress per HMC step is roughly
- proportional to step_size * n_leapfrog_steps.
- x: Tensor containing the value(s) of the random variable(s) to update.
- target_log_prob_fn: Python callable which takes an argument like `initial_x`
- and returns its (possibly unnormalized) log-density under the target
- distribution.
- event_dims: List of dimensions that should not be treated as
- independent. This allows for multiple chains to be run independently
- in parallel. Default is (), i.e., all dimensions are independent.
- x_log_prob (optional): Tensor containing the cached output of a previous
- call to `target_log_prob_fn()` evaluated at `x` (such as that provided by
- a previous call to `kernel()`). Providing `x_log_prob` and
- `x_grad` saves one gradient computation per call to `kernel()`.
- x_grad (optional): Tensor containing the cached gradient of
- `target_log_prob_fn()` evaluated at `x` (such as that provided by
- a previous call to `kernel()`). Providing `x_log_prob` and
- `x_grad` saves one gradient computation per call to `kernel()`.
- name: Python `str` name prefixed to Ops created by this function.
-
- Returns:
- updated_x: The updated variable(s) x. Has shape matching `initial_x`.
- acceptance_probs: Tensor with the acceptance probabilities for the final
- iteration. This is useful for diagnosing step size problems etc. Has
- shape matching `target_log_prob_fn(initial_x)`.
- new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`.
- new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at
- `updated_x`.
+ This function can update multiple chains in parallel. It assumes that all
+ leftmost dimensions of `current_state` index independent chain states (and are
+ therefore updated independently). The output of `target_log_prob_fn()` should
+ sum log-probabilities across all event dimensions. Slices along the rightmost
+ dimensions may have different target distributions; for example,
+ `current_state[0, :]` could have a different target distribution from
+ `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
+ independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
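+
+ For instance, here is a minimal sketch of these shape conventions (the names
+ are illustrative only, not part of the API):
+
+ ```python
+ # Four independent chains over a 3-dimensional state.
+ current_state = tf.random_normal([4, 3])
+
+ def target_log_prob_fn(x):
+   # Sum log-probs over the rightmost (event) dimension; result shape: [4].
+   return tf.reduce_sum(-0.5 * tf.square(x), axis=-1)
+ ```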
#### Examples:
+ ##### Simple chain with warm-up.
+
```python
+ tfd = tf.contrib.distributions
+
# Tuning acceptance rates:
+ dtype = np.float32
target_accept_rate = 0.631
- def target_log_prob(x):
- # Standard normal
- return tf.reduce_sum(-0.5 * tf.square(x))
- initial_x = tf.zeros([10])
- initial_log_prob = target_log_prob(initial_x)
- initial_grad = tf.gradients(initial_log_prob, initial_x)[0]
- # Algorithm state
- x = tf.Variable(initial_x, name='x')
- step_size = tf.Variable(1., name='step_size')
- last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob')
- last_grad = tf.Variable(initial_grad, name='last_grad')
- # Compute updates
- new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x,
- target_log_prob,
- event_dims=[0],
- x_log_prob=last_log_prob)
- x_update = tf.assign(x, new_x)
- log_prob_update = tf.assign(last_log_prob, log_prob)
- grad_update = tf.assign(last_grad, grad)
- step_size_update = tf.assign(step_size,
- tf.where(acceptance_prob > target_accept_rate,
- step_size * 1.01, step_size / 1.01))
- adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update]
- sampling_updates = [x_update, log_prob_update, grad_update]
-
- sess = tf.Session()
- sess.run(tf.global_variables_initializer())
+ num_warmup_iter = 500
+ num_chain_iter = 500
+
+ x = tf.get_variable(name="x", initializer=dtype(1))
+ step_size = tf.get_variable(name="step_size", initializer=dtype(1))
+
+ target = tfd.Normal(loc=dtype(0), scale=dtype(1))
+
+ new_x, other_results = hmc.kernel(
+ target_log_prob_fn=target.log_prob,
+ current_state=x,
+ step_size=step_size,
+ num_leapfrog_steps=3)
+
+ x_update = x.assign(new_x)
+
+ step_size_update = step_size.assign_add(
+ step_size * tf.where(
+ other_results.acceptance_probs > target_accept_rate,
+ 0.01, -0.01))
+
+ warmup = tf.group(x_update, step_size_update)
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
+
+ sess.graph.finalize() # No more graph building.
+
# Warm up the sampler and adapt the step size
- for i in xrange(500):
- sess.run(adaptive_updates)
- # Collect samples without adapting step size
- samples = np.zeros([500, 10])
- for i in xrange(500):
- x_val, _ = sess.run([new_x, sampling_updates])
- samples[i] = x_val
- ```
+ for _ in xrange(num_warmup_iter):
+ sess.run(warmup)
- ```python
- # Empirical-Bayes estimation of a hyperparameter by MCMC-EM:
-
- # Problem setup
- N = 150
- D = 10
- x = np.random.randn(N, D).astype(np.float32)
- true_sigma = 0.5
- true_beta = true_sigma * np.random.randn(D).astype(np.float32)
- y = x.dot(true_beta) + np.random.randn(N).astype(np.float32)
-
- def log_prior(beta, log_sigma):
- return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) -
- log_sigma)
- def regression_log_joint(beta, log_sigma, x, y):
- # This function returns log p(beta | log_sigma) + log p(y | x, beta).
- means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True)
- means = tf.squeeze(means)
- log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means))
- return log_prior(beta, log_sigma) + log_likelihood
- def log_joint_partial(beta):
- return regression_log_joint(beta, log_sigma, x, y)
- # Our estimate of log(sigma)
- log_sigma = tf.Variable(0., name='log_sigma')
- # The state of the Markov chain
- beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta')
- new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial,
- event_dims=[0])
- beta_update = tf.assign(beta, new_beta)
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
- with tf.control_dependencies([beta_update]):
- log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma),
- var_list=[log_sigma])
-
- sess = tf.Session()
- sess.run(tf.global_variables_initializer())
- log_sigma_history = np.zeros(1000)
- for i in xrange(1000):
- log_sigma_val, _ = sess.run([log_sigma, log_sigma_update])
- log_sigma_history[i] = log_sigma_val
- # Should converge to something close to true_sigma
- plt.plot(np.exp(log_sigma_history))
+ # Collect samples without adapting step size
+ samples = np.zeros([num_chain_iter])
+ for i in xrange(num_chain_iter):
+ _, x_, target_log_prob_, grad_ = sess.run([
+ x_update,
+ x,
+ other_results.current_target_log_prob,
+ other_results.current_grads_target_log_prob])
+ samples[i] = x_
+
+ print(samples.mean(), samples.std())
```
- """
- with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]):
- potential_and_grad = _make_potential_and_grad(target_log_prob_fn)
- x = ops.convert_to_tensor(x, name='x')
-
- x_shape = array_ops.shape(x)
- m = random_ops.random_normal(x_shape, dtype=x.dtype)
-
- kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims)
-
- if (x_log_prob is not None) and (x_grad is not None):
- log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type
- else:
- if x_log_prob is not None:
- logging.warn('x_log_prob was provided, but x_grad was not,'
- ' so x_log_prob was not used.')
- if x_grad is not None:
- logging.warn('x_grad was provided, but x_log_prob was not,'
- ' so x_grad was not used.')
- log_potential_0, grad_0 = potential_and_grad(x)
-
- new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator(
- step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0)
-
- kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims)
-
- energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0
- # Treat NaN as infinite energy (and therefore guaranteed rejection).
- energy_change = array_ops.where(
- math_ops.is_nan(energy_change),
- array_ops.fill(array_ops.shape(energy_change),
- energy_change.dtype.as_numpy_dtype(np.inf)),
- energy_change)
- acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.))
- accepted = (
- random_ops.random_uniform(
- array_ops.shape(acceptance_probs), dtype=x.dtype)
- < acceptance_probs)
- new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0)
-
- # TODO(b/65738010): This should work, but it doesn't for now.
- # reduced_shape = math_ops.reduced_shape(x_shape, event_dims)
- reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims,
- keep_dims=True))
- accepted = array_ops.reshape(accepted, reduced_shape)
- accepted = math_ops.logical_or(
- accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool))
- new_x = array_ops.where(accepted, new_x, x)
- new_grad = -array_ops.where(accepted, grad_1, grad_0)
-
- # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect
- # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This
- # should be fixed.
- return new_x, acceptance_probs, new_log_prob, new_grad
-
-
-def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum,
- potential_and_grad, initial_grad, name=None):
- """Applies `n_steps` steps of the leapfrog integrator.
-
- This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing
- gradient computations where possible.
- Args:
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `initial_position`. Larger step sizes lead to faster progress, but
- too-large step sizes lead to larger discretization error and
- worse energy conservation.
- n_steps: Number of steps to run the leapfrog integrator.
- initial_position: Tensor containing the value(s) of the position variable(s)
- to update.
- initial_momentum: Tensor containing the value(s) of the momentum variable(s)
- to update.
- potential_and_grad: Python callable that takes a position tensor like
- `initial_position` and returns the potential energy and its gradient at
- that position.
- initial_grad: Tensor with the value of the gradient of the potential energy
- at `initial_position`.
- name: Python `str` name prefixed to Ops created by this function.
+ ##### Sample from a more complicated posterior.
- Returns:
- updated_position: Updated value of the position.
- updated_momentum: Updated value of the momentum.
- new_potential: Potential energy of the new position. Has shape matching
- `potential_and_grad(initial_position)`.
- new_grad: Gradient from potential_and_grad() evaluated at the new position.
- Has shape matching `initial_position`.
+ I.e.,
- Example: Simple quadratic potential.
+ ```none
+ W ~ MVN(loc=0, scale=sigma * eye(dims))
+ for i=1...num_samples:
+ X[i] ~ MVN(loc=0, scale=eye(dims))
+ eps[i] ~ Normal(loc=0, scale=1)
+ Y[i] = X[i].T * W + eps[i]
+ ```
```python
- def potential_and_grad(position):
- return tf.reduce_sum(0.5 * tf.square(position)), position
- position = tf.placeholder(np.float32)
- momentum = tf.placeholder(np.float32)
- potential, grad = potential_and_grad(position)
- new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator(
- 0.1, 3, position, momentum, potential_and_grad, grad)
-
- sess = tf.Session()
- position_val = np.random.randn(10)
- momentum_val = np.random.randn(10)
- potential_val, grad_val = sess.run([potential, grad],
- {position: position_val})
- positions = np.zeros([100, 10])
- for i in xrange(100):
- position_val, momentum_val, potential_val, grad_val = sess.run(
- [new_position, new_momentum, new_potential, new_grad],
- {position: position_val, momentum: momentum_val})
- positions[i] = position_val
- # Should trace out sinusoidal dynamics.
- plt.plot(positions[:, 0])
- ```
- """
- def leapfrog_wrapper(step_size, x, m, grad, l):
- x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad)
- return step_size, x, m, grad, l + 1
+ tfd = tf.contrib.distributions
+
+ def make_training_data(num_samples, dims, sigma):
+ dt = np.asarray(sigma).dtype
+ zeros = tf.zeros(dims, dtype=dt)
+ x = tfd.MultivariateNormalDiag(
+ loc=zeros).sample(num_samples, seed=1)
+ w = tfd.MultivariateNormalDiag(
+ loc=zeros,
+ scale_identity_multiplier=sigma).sample(seed=2)
+ noise = tfd.Normal(
+ loc=dt.type(0),
+ scale=dt.type(1)).sample(num_samples, seed=3)
+ y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
+ return y, x, w
+
+ def make_prior(sigma, dims):
+ # p(w | sigma)
+ return tfd.MultivariateNormalDiag(
+ loc=tf.zeros([dims], dtype=sigma.dtype),
+ scale_identity_multiplier=sigma)
+
+ def make_likelihood(x, w):
+ # p(y | x, w)
+ return tfd.MultivariateNormalDiag(
+ loc=tf.tensordot(x, w, axes=[[1], [0]]))
+
+ # Setup assumptions.
+ dtype = np.float32
+ num_samples = 150
+ dims = 10
+ num_iters = int(5e3)
+
+ true_sigma = dtype(0.5)
+ y, x, true_weights = make_training_data(num_samples, dims, true_sigma)
+
+ # Estimate of `log(true_sigma)`.
+ log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
+ sigma = tf.exp(log_sigma)
+
+ # State of the Markov chain.
+ weights = tf.get_variable(
+ name="weights",
+ initializer=np.random.randn(dims).astype(dtype))
+
+ prior = make_prior(sigma, dims)
+
+ def joint_log_prob_fn(w):
+ # f(w) = log p(w, y | x)
+ return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)
+
+ weights_update = weights.assign(
+ hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
+ current_state=weights,
+ step_size=0.1,
+ num_leapfrog_steps=5)[0])
+
+ with tf.control_dependencies([weights_update]):
+ loss = -prior.log_prob(weights)
+
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])
+
+ sess = tf.InteractiveSession()
+ tf.global_variables_initializer().run()
- def counter_fn(a, b, c, d, counter): # pylint: disable=unused-argument
- return counter < n_steps
+ sess.graph.finalize()  # No more graph building.
- with ops.name_scope(name, 'leapfrog_integrator',
- [step_size, n_steps, initial_position, initial_momentum,
- initial_grad]):
- _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop(
- counter_fn, leapfrog_wrapper, [step_size, initial_position,
- initial_momentum, initial_grad,
- array_ops.constant(0)], back_prop=False)
- # We're counting on the runtime to eliminate this redundant computation.
- new_potential, new_grad = potential_and_grad(new_x)
- return new_x, new_m, new_potential, new_grad
+ sigma_history = np.zeros(num_iters, dtype)
+ weights_history = np.zeros([num_iters, dims], dtype)
+ for i in xrange(num_iters):
+ _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights])
+ weights_history[i, :] = weights_
+ sigma_history[i] = sigma_
-def leapfrog_step(step_size, position, momentum, potential_and_grad, grad,
- name=None):
- """Applies one step of the leapfrog integrator.
+ true_weights_ = sess.run(true_weights)
- Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2.
+ # Should converge to something close to true_sigma.
+ plt.plot(sigma_history);
+ plt.ylabel("sigma");
+ plt.xlabel("iteration");
+ ```
Args:
- step_size: Scalar step size or array of step sizes for the
- leapfrog integrator. Broadcasts to the shape of
- `position`. Larger step sizes lead to faster progress, but
- too-large step sizes lead to larger discretization error and
- worse energy conservation.
- position: Tensor containing the value(s) of the position variable(s)
- to update.
- momentum: Tensor containing the value(s) of the momentum variable(s)
- to update.
- potential_and_grad: Python callable that takes a position tensor like
- `position` and returns the potential energy and its gradient at that
- position.
- grad: Tensor with the value of the gradient of the potential energy
- at `position`.
+ target_log_prob_fn: Python callable which takes an argument like
+ `current_state` (or `*current_state` if it's a list) and returns its
+ (possibly unnormalized) log-density under the target distribution.
+ current_state: `Tensor` or Python `list` of `Tensor`s representing the
+ current state(s) of the Markov chain(s). The first `r` dimensions index
+ independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+ step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
+ for the leapfrog integrator. Must broadcast with the shape of
+ `current_state`. Larger step sizes lead to faster progress, but too-large
+ step sizes make rejection exponentially more likely. When possible, it's
+ often helpful to match per-variable step sizes to the standard deviations
+ of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ seed: Python integer to seed the random number generator.
+ current_target_log_prob: (Optional) `Tensor` representing the value of
+ `target_log_prob_fn` at the `current_state`. The only reason to
+ specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+ representing gradient of `current_target_log_prob` at the `current_state`
+ and wrt the `current_state`. Must have same shape as `current_state`. The
+ only reason to specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_kernel").
Returns:
- updated_position: Updated value of the position.
- updated_momentum: Updated value of the momentum.
- new_potential: Potential energy of the new position. Has shape matching
- `potential_and_grad(position)`.
- new_grad: Gradient from potential_and_grad() evaluated at the new position.
- Has shape matching `position`.
+ accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
+ of the Markov chain(s) after one HMC iteration. Has same shape as
+ `current_state`.
+ kernel_results: `collections.namedtuple` of internal calculations used to
+ advance the chain.
+
+ Raises:
+ ValueError: if there isn't one `step_size` or a list with same length as
+ `current_state`.
+ """
+ with ops.name_scope(
+ name, "hmc_kernel",
+ [current_state, step_size, num_leapfrog_steps, seed,
+ current_target_log_prob, current_grads_target_log_prob]):
+ with ops.name_scope("initialize"):
+ [current_state_parts, step_sizes, current_target_log_prob,
+ current_grads_target_log_prob] = _prepare_args(
+ target_log_prob_fn, current_state, step_size,
+ current_target_log_prob, current_grads_target_log_prob,
+ maybe_expand=True)
+ independent_chain_ndims = distributions_util.prefer_static_rank(
+ current_target_log_prob)
+ current_momentums = []
+ for s in current_state_parts:
+ current_momentums.append(random_ops.random_normal(
+ shape=array_ops.shape(s),
+ dtype=s.dtype.base_dtype,
+ seed=seed))
+ seed = distributions_util.gen_new_seed(
+ seed, salt="hmc_kernel_momentums")
+
+ num_leapfrog_steps = ops.convert_to_tensor(
+ num_leapfrog_steps,
+ dtype=dtypes.int32,
+ name="num_leapfrog_steps")
+ [
+ proposed_momentums,
+ proposed_state_parts,
+ proposed_target_log_prob,
+ proposed_grads_target_log_prob,
+ ] = _leapfrog_integrator(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ num_leapfrog_steps,
+ current_target_log_prob,
+ current_grads_target_log_prob)
+
+ energy_change = _compute_energy_change(current_target_log_prob,
+ current_momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims)
+
+ # u < exp(min(-energy, 0)), where u~Uniform[0,1)
+ # ==> -log(u) >= max(e, 0)
+ # ==> -log(u) >= e
+ # (Perhaps surprisingly, we don't have a better way to obtain a random
+ # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
+ # maxval=np.inf)` won't work.)
+ random_uniform = random_ops.random_uniform(
+ shape=array_ops.shape(energy_change),
+ dtype=energy_change.dtype,
+ seed=seed)
+ random_positive = -math_ops.log(random_uniform)
+ is_accepted = random_positive >= energy_change
+
+ accepted_target_log_prob = array_ops.where(is_accepted,
+ proposed_target_log_prob,
+ current_target_log_prob)
+
+ accepted_state_parts = [_choose(is_accepted,
+ proposed_state_part,
+ current_state_part,
+ independent_chain_ndims)
+ for current_state_part, proposed_state_part
+ in zip(current_state_parts, proposed_state_parts)]
+
+ accepted_grads_target_log_prob = [
+ _choose(is_accepted,
+ proposed_grad,
+ grad,
+ independent_chain_ndims)
+ for proposed_grad, grad
+ in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)]
+
+ maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
+ return [
+ maybe_flatten(accepted_state_parts),
+ KernelResults(
+ acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
+ current_grads_target_log_prob=accepted_grads_target_log_prob,
+ current_target_log_prob=accepted_target_log_prob,
+ energy_change=energy_change,
+ is_accepted=is_accepted,
+ proposed_grads_target_log_prob=proposed_grads_target_log_prob,
+ proposed_state=maybe_flatten(proposed_state_parts),
+ proposed_target_log_prob=proposed_target_log_prob,
+ random_positive=random_positive,
+ ),
+ ]
+
+
+def _leapfrog_integrator(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ num_leapfrog_steps,
+ current_target_log_prob=None,
+ current_grads_target_log_prob=None,
+ name=None):
+ """Applies `num_leapfrog_steps` of the leapfrog integrator.
+
+ Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.
+
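+ In this convention, one leapfrog step with step size `eps` is (a sketch,
+ writing `grad` for the gradient of `target_log_prob_fn`):
+
+ ```none
+ momentum += 0.5 * eps * grad(position)   # half momentum update
+ position += eps * momentum               # full state update
+ momentum += 0.5 * eps * grad(position)   # half momentum update
+ ```
+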
+ #### Examples:
- Example: Simple quadratic potential.
+ ##### Simple quadratic potential.
```python
- def potential_and_grad(position):
- # Simple quadratic potential
- return tf.reduce_sum(0.5 * tf.square(position)), position
+ tfd = tf.contrib.distributions
+
+ dims = 10
+ num_iter = int(1e3)
+ dtype = np.float32
+
position = tf.placeholder(np.float32)
momentum = tf.placeholder(np.float32)
- potential, grad = potential_and_grad(position)
- new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step(
- 0.1, position, momentum, potential_and_grad, grad)
-
- sess = tf.Session()
- position_val = np.random.randn(10)
- momentum_val = np.random.randn(10)
- potential_val, grad_val = sess.run([potential, grad],
- {position: position_val})
- positions = np.zeros([100, 10])
- for i in xrange(100):
- position_val, momentum_val, potential_val, grad_val = sess.run(
- [new_position, new_momentum, new_potential, new_grad],
- {position: position_val, momentum: momentum_val})
- positions[i] = position_val
- # Should trace out sinusoidal dynamics.
- plt.plot(positions[:, 0])
+
+ [
+ new_momentums,
+ new_positions,
+ ] = hmc._leapfrog_integrator(
+ current_momentums=[momentum],
+ target_log_prob_fn=tfd.MultivariateNormalDiag(
+ loc=tf.zeros(dims, dtype)).log_prob,
+ current_state_parts=[position],
+ step_sizes=[0.1],
+ num_leapfrog_steps=3)[:2]
+
+ sess = tf.Session()
+ sess.graph.finalize()  # No more graph building.
+
+ momentum_ = np.random.randn(dims).astype(dtype)
+ position_ = np.random.randn(dims).astype(dtype)
+
+ positions = np.zeros([num_iter, dims], dtype)
+ for i in xrange(num_iter):
+ momentum_, position_ = sess.run(
+ [new_momentums[0], new_positions[0]],
+ feed_dict={position: position_, momentum: momentum_})
+ positions[i] = position_
+
+ plt.plot(positions[:, 0]); # Sinusoidal.
```
+
+ Args:
+ current_momentums: Tensor containing the value(s) of the momentum
+ variable(s) to update.
+ target_log_prob_fn: Python callable which takes an argument like
+ `*current_state_parts` and returns its (possibly unnormalized) log-density
+ under the target distribution.
+ current_state_parts: Python `list` of `Tensor`s representing the current
+ state(s) of the Markov chain(s). The first `independent_chain_ndims` of
+ the `Tensor`(s) index different chains.
+ step_sizes: Python `list` of `Tensor`s representing the step size for the
+ leapfrog integrator. Must broadcast with the shape of
+ `current_state_parts`. Larger step sizes lead to faster progress, but
+ too-large step sizes make rejection exponentially more likely. When
+ possible, it's often helpful to match per-variable step sizes to the
+ standard deviations of the target distribution in each variable.
+ num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
+ for. Total progress per HMC step is roughly proportional to `step_size *
+ num_leapfrog_steps`.
+ current_target_log_prob: (Optional) `Tensor` representing the value of
+ `target_log_prob_fn(*current_state_parts)`. The only reason to specify
+ this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s
+ representing gradient of `target_log_prob_fn(*current_state_parts`) wrt
+ `current_state_parts`. Must have same shape as `current_state_parts`. The
+ only reason to specify this argument is to reduce TF graph size.
+ Default value: `None` (i.e., compute as needed).
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., "hmc_leapfrog_integrator").
+
+ Returns:
+ proposed_momentums: Updated value of the momentum.
+ proposed_state_parts: Tensor or Python list of `Tensor`s representing the
+ state(s) of the Markov chain(s) after `num_leapfrog_steps` steps. Has same
+ shape as input `current_state_parts`.
+ proposed_target_log_prob: `Tensor` representing the value of
+ `target_log_prob_fn` at `proposed_state_parts`.
+ proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
+ `proposed_state_parts`.
+
+ Raises:
+ ValueError: if `len(momentums) != len(state_parts)`.
+ ValueError: if `len(state_parts) != len(step_sizes)`.
+ ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
+ TypeError: if `not target_log_prob.dtype.is_floating`.
"""
- with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum,
- grad]):
- momentum -= 0.5 * step_size * grad
- position += step_size * momentum
- potential, grad = potential_and_grad(position)
- momentum -= 0.5 * step_size * grad
-
- return position, momentum, potential, grad
+ def _loop_body(step,
+ current_momentums,
+ current_state_parts,
+ ignore_current_target_log_prob, # pylint: disable=unused-argument
+ current_grads_target_log_prob):
+ return [step + 1] + list(_leapfrog_step(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ current_grads_target_log_prob))
+
+ with ops.name_scope(
+ name, "hmc_leapfrog_integrator",
+ [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
+ current_target_log_prob, current_grads_target_log_prob]):
+ if len(current_momentums) != len(current_state_parts):
+ raise ValueError("`momentums` must be in one-to-one correspondence "
+ "with `state_parts`")
+ num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps,
+ name="num_leapfrog_steps")
+ current_target_log_prob, current_grads_target_log_prob = (
+ _maybe_call_fn_and_grads(
+ target_log_prob_fn,
+ current_state_parts,
+ current_target_log_prob,
+ current_grads_target_log_prob))
+ return control_flow_ops.while_loop(
+ cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
+ body=_loop_body,
+ loop_vars=[
+ np.int32(0), # iter_
+ current_momentums,
+ current_state_parts,
+ current_target_log_prob,
+ current_grads_target_log_prob,
+ ],
+ back_prop=False)[1:] # Lop-off "iter_".
+
+
+def _leapfrog_step(current_momentums,
+ target_log_prob_fn,
+ current_state_parts,
+ step_sizes,
+ current_grads_target_log_prob,
+ name=None):
+ """Applies one step of the leapfrog integrator."""
+ with ops.name_scope(
+ name, "_leapfrog_step",
+ [current_momentums, current_state_parts, step_sizes,
+ current_grads_target_log_prob]):
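+ # Half step: update momentums using the current log-prob gradients.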
+ proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+ in zip(current_momentums,
+ step_sizes,
+ current_grads_target_log_prob)]
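+ # Full step: update the state using the updated momentums.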
+ proposed_state_parts = [x + ss * m for x, ss, m
+ in zip(current_state_parts,
+ step_sizes,
+ proposed_momentums)]
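+ # Recompute the target log-prob and its gradients at the proposed state.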
+ proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
+ if not proposed_target_log_prob.dtype.is_floating:
+ raise TypeError("`target_log_prob_fn` must produce a `Tensor` "
+ "with `float` `dtype`.")
+ proposed_grads_target_log_prob = gradients_ops.gradients(
+ proposed_target_log_prob, proposed_state_parts)
+ if any(g is None for g in proposed_grads_target_log_prob):
+ raise ValueError(
+ "Encountered `None` gradient. Does your target `target_log_prob_fn` "
+ "access all `tf.Variable`s via `tf.get_variable`?\n"
+ " current_state_parts: {}\n"
+ " proposed_state_parts: {}\n"
+ " proposed_grads_target_log_prob: {}".format(
+ current_state_parts,
+ proposed_state_parts,
+ proposed_grads_target_log_prob))
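+ # Second half step: update momentums using the gradients at the new state.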
+ proposed_momentums = [m + 0.5 * ss * g for m, ss, g
+ in zip(proposed_momentums,
+ step_sizes,
+ proposed_grads_target_log_prob)]
+ return [
+ proposed_momentums,
+ proposed_state_parts,
+ proposed_target_log_prob,
+ proposed_grads_target_log_prob,
+ ]
+
+
+def _compute_energy_change(current_target_log_prob,
+ current_momentums,
+ proposed_target_log_prob,
+ proposed_momentums,
+ independent_chain_ndims,
+ name=None):
+ """Helper to `kernel` which computes the energy change."""
+ with ops.name_scope(
+ name, "compute_energy_change",
+ ([current_target_log_prob, proposed_target_log_prob,
+ independent_chain_ndims] +
+ current_momentums + proposed_momentums)):
+ # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy;
+ # the full names are a mouthful and the short forms let us inline more.
+ lk0, lk1 = [], []
+ for current_momentum, proposed_momentum in zip(current_momentums,
+ proposed_momentums):
+ axis = math_ops.range(independent_chain_ndims,
+ array_ops.rank(current_momentum))
+ lk0.append(_log_sum_sq(current_momentum, axis))
+ lk1.append(_log_sum_sq(proposed_momentum, axis))
+
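+ # Each `lk` entry is `log(sum(m**2))` for one state part; `-log(2)` plus
+ # `logsumexp` over parts yields `log(0.5 * total_sum_sq)`, the log kinetic
+ # energy.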
+ lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1),
+ axis=-1)
+ lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1),
+ axis=-1)
+ lp0 = -current_target_log_prob # log_potential
+ lp1 = -proposed_target_log_prob # proposed_log_potential
+ x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
+ axis=-1)
+
+ # The sum is NaN if any element is NaN or we see both +Inf and -Inf.
+ # Thus we will replace such rows with infinite energy change which implies
+ # rejection. Recall that float-comparisons with NaN are always False.
+ is_sum_determinate = (
+ math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) &
+ math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1))
+ is_sum_determinate = array_ops.tile(
+ is_sum_determinate[..., array_ops.newaxis],
+ multiples=array_ops.concat([
+ array_ops.ones(array_ops.rank(is_sum_determinate),
+ dtype=dtypes.int32),
+ [4],
+ ], axis=0))
+ x = array_ops.where(is_sum_determinate,
+ x,
+ array_ops.fill(array_ops.shape(x),
+ value=x.dtype.as_numpy_dtype(np.inf)))
+
+ return math_ops.reduce_sum(x, axis=-1)
+
+
+def _choose(is_accepted,
+ accepted,
+ rejected,
+ independent_chain_ndims,
+ name=None):
+ """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where."""
+ def _expand_is_accepted_like(x):
+ with ops.name_scope("_choose"):
+ expand_shape = array_ops.concat([
+ array_ops.shape(is_accepted),
+ array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)],
+ dtype=dtypes.int32),
+ ], axis=0)
+ multiples = array_ops.concat([
+ array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32),
+ array_ops.shape(x)[independent_chain_ndims:],
+ ], axis=0)
+ m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape),
+ multiples)
+ m.set_shape(x.shape)
+ return m
+ with ops.name_scope(name, "_choose", values=[
+ is_accepted, accepted, rejected, independent_chain_ndims]):
+ return array_ops.where(_expand_is_accepted_like(accepted),
+ accepted,
+ rejected)
+
+
+def _maybe_call_fn_and_grads(fn,
+ fn_arg_list,
+ fn_result=None,
+ grads_fn_result=None,
+ description="target_log_prob"):
+ """Helper which computes `fn_result` and `grads` if needed."""
+ fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list)
+ else [fn_arg_list])
+ if fn_result is None:
+ fn_result = fn(*fn_arg_list)
+ if not fn_result.dtype.is_floating:
+ raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format(
+ description))
+ if grads_fn_result is None:
+ grads_fn_result = gradients_ops.gradients(
+ fn_result, fn_arg_list)
+ if len(fn_arg_list) != len(grads_fn_result):
+ raise ValueError("`{}` must be in one-to-one correspondence with "
+ "`grads_{}`".format(*[description]*2))
+ if any(g is None for g in grads_fn_result):
+ raise ValueError("Encountered `None` gradient.")
+ return fn_result, grads_fn_result
+
+
+def _prepare_args(target_log_prob_fn, state, step_size,
+ target_log_prob=None, grads_target_log_prob=None,
+ maybe_expand=False, description="target_log_prob"):
+ """Helper which processes input args to meet list-like assumptions."""
+ state_parts = list(state) if _is_list_like(state) else [state]
+ state_parts = [ops.convert_to_tensor(s, name="state")
+ for s in state_parts]
+ target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads(
+ target_log_prob_fn,
+ state_parts,
+ target_log_prob,
+ grads_target_log_prob,
+ description)
+ step_sizes = list(step_size) if _is_list_like(step_size) else [step_size]
+ step_sizes = [
+ ops.convert_to_tensor(
+ s, name="step_size", dtype=target_log_prob.dtype)
+ for s in step_sizes]
+ if len(step_sizes) == 1:
+ step_sizes *= len(state_parts)
+ if len(state_parts) != len(step_sizes):
+ raise ValueError("There should be exactly one `step_size` or it should "
+ "have same length as `current_state`.")
+ maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0]
+ return [
+ maybe_flatten(state_parts),
+ maybe_flatten(step_sizes),
+ target_log_prob,
+ grads_target_log_prob,
+ ]
+
+
+def _is_list_like(x):
+ """Helper which returns `True` if input is `list`-like."""
+ return isinstance(x, (tuple, list))
+
+
+def _log_sum_sq(x, axis=None):
+ """Computes log(sum(x**2))."""
+ return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis)