author    | Dustin Tran <trandustin@google.com> | 2018-02-22 20:02:33 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-22 20:05:59 -0800
commit    | bff1648a179aa522fb13e2eb1b26f8464da26af6 (patch)
tree      | bb8ac7ed07589bc8835c6892eefa81ae6efd1ae8 /tensorflow/contrib/bayesflow
parent    | befd8234e1c209b26457eb5df37d2952004bdeaf (diff)
Unify metropolis_hastings interface with HMC kernel.
PiperOrigin-RevId: 186716023
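In practical terms, the unification makes `mh.kernel` mirror `hmc.kernel`: both now return a `(next_state, kernel_results)` pair, and both report a `log_accept_ratio` field instead of `acceptance_probs`. A minimal sketch of the shared call pattern, adapted from the docstring examples in the diff below (TF 1.x `contrib` APIs assumed):

```python
import tensorflow as tf

tfd = tf.contrib.distributions
tfp = tf.contrib.bayesflow

loc = tf.get_variable("loc", initializer=1.)

def target_log_prob_fn(loc):
  # Unnormalized log-density of the target, summed over event dimensions.
  return tfd.Normal(loc=0., scale=0.1).log_prob(loc)

# Both the HMC and MH kernels now share this return signature.
next_state, kernel_results = tfp.metropolis_hastings.kernel(
    target_log_prob_fn=target_log_prob_fn,
    proposal_fn=tfp.metropolis_hastings.proposal_normal(scale=0.05),
    current_state=loc)
loc_update = loc.assign(next_state)
```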
Diffstat (limited to 'tensorflow/contrib/bayesflow')
 tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
 tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py
 tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
 tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py
 tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py
5 files changed, 558 insertions(+), 297 deletions(-)
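The headline rename: the `acceptance_probs` field of each kernel's results is replaced by `log_accept_ratio`. Wherever the old probabilities are still needed, the updated tests below recover them with `np.exp(np.minimum(log_accept_ratio, 0.))`. A small NumPy sketch of that conversion (the helper name is illustrative, not part of the API):

```python
import numpy as np

def acceptance_probs(log_accept_ratio):
  # exp(min(log_accept_ratio, 0)) == min(1, p(proposed) / p(current)),
  # i.e., the Metropolis acceptance probability recovered from log space.
  return np.exp(np.minimum(log_accept_ratio, 0.))

print(acceptance_probs(np.array([-0.3, 0.7])))  # ~[0.741, 1.0]
```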
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
index 5bd834e562..819095a060 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
@@ -224,12 +224,13 @@ class HMCTest(test.TestCase):
       expected_exp_x = self._shape_param / self._rate_param

-      acceptance_probs_, samples_, expected_x_ = sess.run(
-          [kernel_results.acceptance_probs, samples, expected_x],
+      log_accept_ratio_, samples_, expected_x_ = sess.run(
+          [kernel_results.log_accept_ratio, samples, expected_x],
           feed_dict)

       actual_x = samples_.mean()
       actual_exp_x = np.exp(samples_).mean()
+      acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

       logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
           expected_x_, expected_exp_x))
@@ -237,10 +238,10 @@ class HMCTest(test.TestCase):
           actual_x, actual_exp_x))
       self.assertNear(actual_x, expected_x_, 2e-2)
       self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
-      self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
-                          acceptance_probs_ > 0.5)
-      self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
-                          acceptance_probs_ <= 1.)
+      self.assertAllEqual(np.ones_like(acceptance_probs, np.bool),
+                          acceptance_probs > 0.5)
+      self.assertAllEqual(np.ones_like(acceptance_probs, np.bool),
+                          acceptance_probs <= 1.)

   def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
     with self.test_session(graph=ops.Graph()) as sess:
@@ -265,7 +266,7 @@ class HMCTest(test.TestCase):
           -x - x**2,  # Non-constant gradient.
           array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
     # This log_prob has the property that it is likely to attract
-    # the HMC flow toward, and below, zero...but for x <=0,
+    # the flow toward, and below, zero...but for x <=0,
     # log_prob(x) = -inf, which should result in rejection, as well
     # as a non-finite log_prob. Thus, this distribution gives us an opportunity
     # to test out the kernel results ability to correctly capture rejections due
@@ -305,11 +306,10 @@ class HMCTest(test.TestCase):
       self.assertLess(0, neg_inf_mask.sum())
       # We better have some rejections due to something other than -inf.
       self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
-      # We better have been accepted a decent amount, even near the end of the
-      # chain, or else this HMC run just got stuck at some point.
+      # We better have accepted a decent amount, even near end of the chain.
       self.assertLess(
           0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
-      # We better not have any NaNs in proposed state or log_prob.
+      # We better not have any NaNs in states or log_prob.
       # We may have some NaN in grads, which involve multiplication/addition due
       # to gradient rules. This is the known "NaN grad issue with tf.where."
       self.assertAllEqual(np.zeros_like(states_),
@@ -333,9 +333,11 @@ class HMCTest(test.TestCase):
       np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

       # Acceptance probs are zero whenever proposed state is negative.
+      acceptance_probs = np.exp(np.minimum(
+          kernel_results_.log_accept_ratio, 0.))
       self.assertAllEqual(
           np.zeros_like(pstates_[neg_inf_mask]),
-          kernel_results_.acceptance_probs[neg_inf_mask])
+          acceptance_probs[neg_inf_mask])

       # The move is accepted ==> state = proposed state.
       self.assertAllEqual(
@@ -383,26 +385,28 @@ class HMCTest(test.TestCase):
           seed=44)

       [
-          acceptance_probs_,
-          bad_acceptance_probs_,
+          log_accept_ratio_,
+          bad_log_accept_ratio_,
           initial_draws_,
           updated_draws_,
           fake_draws_,
       ] = sess.run([
-          kernel_results.acceptance_probs,
-          bad_kernel_results.acceptance_probs,
+          kernel_results.log_accept_ratio,
+          bad_kernel_results.log_accept_ratio,
           initial_draws,
           sample,
           bad_sample,
       ], feed_dict)

       # Confirm step size is small enough that we usually accept.
-      self.assertGreater(acceptance_probs_.mean(), 0.5)
-      self.assertGreater(bad_acceptance_probs_.mean(), 0.5)
+      acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))
+      bad_acceptance_probs = np.exp(np.minimum(bad_log_accept_ratio_, 0.))
+      self.assertGreater(acceptance_probs.mean(), 0.5)
+      self.assertGreater(bad_acceptance_probs.mean(), 0.5)

       # Confirm step size is large enough that we sometimes reject.
-      self.assertLess(acceptance_probs_.mean(), 0.99)
-      self.assertLess(bad_acceptance_probs_.mean(), 0.99)
+      self.assertLess(acceptance_probs.mean(), 0.99)
+      self.assertLess(bad_acceptance_probs.mean(), 0.99)

       _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
                                           updated_draws_.flatten())
@@ -410,9 +414,9 @@ class HMCTest(test.TestCase):
                                           fake_draws_.flatten())
       logging_ops.vlog(1, "acceptance rate for true target: {}".format(
-          acceptance_probs_.mean()))
+          acceptance_probs.mean()))
       logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
-          bad_acceptance_probs_.mean()))
+          bad_acceptance_probs.mean()))
       logging_ops.vlog(1, "K-S p-value for true target: {}".format(
           ks_p_value_true))
       logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
@@ -615,15 +619,16 @@ class HMCTest(test.TestCase):
           step_size=2.,
           num_leapfrog_steps=5,
           seed=46)
-      initial_x_, updated_x_, acceptance_probs_ = sess.run(
-          [initial_x, updated_x, kernel_results.acceptance_probs])
+      initial_x_, updated_x_, log_accept_ratio_ = sess.run(
+          [initial_x, updated_x, kernel_results.log_accept_ratio])
+      acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

       logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
       logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
-      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+      logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_))

       self.assertAllEqual(initial_x_, updated_x_)
-      self.assertEqual(acceptance_probs_, 0.)
+      self.assertEqual(acceptance_probs, 0.)

   def testNanFromGradsDontPropagate(self):
     """Test that update with NaN gradients does not cause NaN in results."""
@@ -638,15 +643,16 @@ class HMCTest(test.TestCase):
           step_size=2.,
           num_leapfrog_steps=5,
           seed=47)
-      initial_x_, updated_x_, acceptance_probs_ = sess.run(
-          [initial_x, updated_x, kernel_results.acceptance_probs])
+      initial_x_, updated_x_, log_accept_ratio_ = sess.run(
+          [initial_x, updated_x, kernel_results.log_accept_ratio])
+      acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

       logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
       logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
-      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
+      logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_))

       self.assertAllEqual(initial_x_, updated_x_)
-      self.assertEqual(acceptance_probs_, 0.)
+      self.assertEqual(acceptance_probs, 0.)
       self.assertAllFinite(
           gradients_ops.gradients(updated_x, initial_x)[0].eval())
@@ -671,10 +677,10 @@ class HMCTest(test.TestCase):
           step_size=0.01,
           num_leapfrog_steps=10,
           seed=48)
-      states_, acceptance_probs_ = sess.run(
-          [states, kernel_results.acceptance_probs])
+      states_, log_accept_ratio_ = sess.run(
+          [states, kernel_results.log_accept_ratio])
       self.assertEqual(dtype, states_.dtype)
-      self.assertEqual(dtype, acceptance_probs_.dtype)
+      self.assertEqual(dtype, log_accept_ratio_.dtype)

   def testChainWorksIn64Bit(self):
     self._testChainWorksDtype(np.float64)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py
index 63d93fad64..f508e5b114 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py
@@ -12,34 +12,195 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for metropolis_hastings.py."""
+"""Tests for Metropolis-Hastings."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import numpy as np
+
 from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh
+from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test


-class McmcStepTest(test.TestCase):
+class MetropolisHastingsTest(test.TestCase):
+
+  def testKernelStateTensor(self):
+    """Test that transition kernel works with tensor input to `state`."""
+    loc = variable_scope.get_variable("loc", initializer=0.)
+
+    def target_log_prob_fn(loc):
+      return normal_lib.Normal(loc=0.0, scale=0.1).log_prob(loc)
+
+    new_state, _ = mh.kernel(
+        target_log_prob_fn=target_log_prob_fn,
+        proposal_fn=mh.proposal_normal(scale=0.05),
+        current_state=loc,
+        seed=231251)
+    loc_update = loc.assign(new_state)
+
+    init = variables.initialize_all_variables()
+    with self.test_session() as sess:
+      sess.run(init)
+      loc_samples = []
+      for _ in range(2500):
+        loc_sample = sess.run(loc_update)
+        loc_samples.append(loc_sample)
+    loc_samples = loc_samples[500:]  # drop samples for burn-in
+
+    self.assertAllClose(np.mean(loc_samples), 0.0, rtol=1e-5, atol=1e-1)
+    self.assertAllClose(np.std(loc_samples), 0.1, rtol=1e-5, atol=1e-1)
+
+  def testKernelStateList(self):
+    """Test that transition kernel works with list input to `state`."""
+    num_chains = 2
+    loc_one = variable_scope.get_variable(
+        "loc_one", [num_chains],
+        initializer=init_ops.zeros_initializer())
+    loc_two = variable_scope.get_variable(
+        "loc_two", [num_chains], initializer=init_ops.zeros_initializer())
+
+    def target_log_prob_fn(loc_one, loc_two):
+      loc = array_ops.stack([loc_one, loc_two])
+      log_prob = mvn_tril_lib.MultivariateNormalTriL(
+          loc=constant_op.constant([0., 0.]),
+          scale_tril=constant_op.constant([[0.1, 0.1], [0.0, 0.1]])).log_prob(
+              loc)
+      return math_ops.reduce_sum(log_prob, 0)
+
+    def proposal_fn(loc_one, loc_two):
+      loc_one_proposal = mh.proposal_normal(scale=0.05)
+      loc_two_proposal = mh.proposal_normal(scale=0.05)
+      loc_one_sample, _ = loc_one_proposal(loc_one)
+      loc_two_sample, _ = loc_two_proposal(loc_two)
+      return [loc_one_sample, loc_two_sample], None
+
+    new_state, _ = mh.kernel(
+        target_log_prob_fn=target_log_prob_fn,
+        proposal_fn=proposal_fn,
+        current_state=[loc_one, loc_two],
+        seed=12415)
+    loc_one_update = loc_one.assign(new_state[0])
+    loc_two_update = loc_two.assign(new_state[1])
+
+    init = variables.initialize_all_variables()
+    with self.test_session() as sess:
+      sess.run(init)
+      loc_one_samples = []
+      loc_two_samples = []
+      for _ in range(10000):
+        loc_one_sample, loc_two_sample = sess.run(
+            [loc_one_update, loc_two_update])
+        loc_one_samples.append(loc_one_sample)
+        loc_two_samples.append(loc_two_sample)
+
+    loc_one_samples = np.array(loc_one_samples)
+    loc_two_samples = np.array(loc_two_samples)
+    loc_one_samples = loc_one_samples[1000:]  # drop samples for burn-in
+    loc_two_samples = loc_two_samples[1000:]  # drop samples for burn-in
+
+    self.assertAllClose(np.mean(loc_one_samples, 0),
+                        np.array([0.] * num_chains),
+                        rtol=1e-5, atol=1e-1)
+    self.assertAllClose(np.mean(loc_two_samples, 0),
+                        np.array([0.] * num_chains),
+                        rtol=1e-5, atol=1e-1)
+    self.assertAllClose(np.std(loc_one_samples, 0),
+                        np.array([0.1] * num_chains),
+                        rtol=1e-5, atol=1e-1)
+    self.assertAllClose(np.std(loc_two_samples, 0),
+                        np.array([0.1] * num_chains),
+                        rtol=1e-5, atol=1e-1)
+
+  def testKernelResultsUsingTruncatedDistribution(self):
+
+    def log_prob(x):
+      return array_ops.where(
+          x >= 0.,
+          -x - x**2,
+          array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
+    # The truncated distribution has the property that it is likely to attract
+    # the flow toward, and below, zero...but for x <=0,
+    # log_prob(x) = -inf, which should result in rejection, as well
+    # as a non-finite log_prob. Thus, this distribution gives us an opportunity
+    # to test out the kernel results ability to correctly capture rejections due
+    # to finite AND non-finite reasons.
+
+    num_results = 1000
+    # Large step size, will give rejections due to going into a region of
+    # log_prob = -inf.
+    step_size = 0.3
+    num_chains = 2
+
+    with self.test_session(graph=ops.Graph()) as sess:
+
+      # Start multiple independent chains.
+      initial_state = ops.convert_to_tensor([0.1] * num_chains)

-  def test_density_increasing_step_accepted(self):
+      states = []
+      is_accepted = []
+      proposed_states = []
+      current_state = initial_state
+      for _ in range(num_results):
+        current_state, kernel_results = mh.kernel(
+            target_log_prob_fn=log_prob,
+            proposal_fn=mh.proposal_uniform(step_size=step_size),
+            current_state=current_state,
+            seed=42)
+        states.append(current_state)
+        proposed_states.append(kernel_results.proposed_state)
+        is_accepted.append(kernel_results.is_accepted)
+
+      states = array_ops.stack(states)
+      proposed_states = array_ops.stack(proposed_states)
+      is_accepted = array_ops.stack(is_accepted)
+      states_, pstates_, is_accepted_ = sess.run(
+          [states, proposed_states, is_accepted])
+
+      # We better have accepted a decent amount, even near end of the chain.
+      self.assertLess(
+          0.1, is_accepted_[int(0.9 * num_results):].mean())
+      # We better not have any NaNs in states.
+      self.assertAllEqual(np.zeros_like(states_),
+                          np.isnan(states_))
+      # We better not have any +inf in states.
+      self.assertAllEqual(np.zeros_like(states_),
+                          np.isposinf(states_))
+
+      # The move is accepted ==> state = proposed state.
+      self.assertAllEqual(
+          states_[is_accepted_],
+          pstates_[is_accepted_],
+      )
+
+      # The move was rejected <==> state[t] == state[t - 1].
+      for t in range(1, num_results):
+        for i in range(num_chains):
+          if is_accepted_[t, i]:
+            self.assertNotEqual(states_[t, i], states_[t - 1, i])
+          else:
+            self.assertEqual(states_[t, i], states_[t - 1, i])
+
+  def testDensityIncreasingStepAccepted(self):
     """Tests that if a transition increases density, it is always accepted."""
     target_log_density = lambda x: - x * x
-    state = variable_scope.get_variable('state', initializer=10.)
+    state = variable_scope.get_variable("state", initializer=10.)
     state_log_density = variable_scope.get_variable(
-        'state_log_density',
+        "state_log_density",
         initializer=target_log_density(state.initialized_value()))
     log_accept_ratio = variable_scope.get_variable(
-        'log_accept_ratio', initializer=0.)
+        "log_accept_ratio", initializer=0.)

     get_next_proposal = lambda x: (x - 1., None)
     step = mh.evolve(state, state_log_density, log_accept_ratio,
@@ -54,7 +215,7 @@ class McmcStepTest(test.TestCase):
       self.assertAlmostEqual(sample, 9 - j)
       self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j))

-  def test_sample_properties(self):
+  def testSampleProperties(self):
     """Tests that the samples converge to the target distribution."""

     def target_log_density(x):
@@ -62,16 +223,16 @@ class McmcStepTest(test.TestCase):
       return - (x - 2.0) * (x - 2.0) * 0.5

     # Use the uniform random walker to generate proposals.
-    proposal_fn = mh.uniform_random_proposal(
+    proposal_fn = mh.proposal_uniform(
         step_size=1.0, seed=1234)

-    state = variable_scope.get_variable('state', initializer=0.0)
+    state = variable_scope.get_variable("state", initializer=0.0)
     state_log_density = variable_scope.get_variable(
-        'state_log_density',
+        "state_log_density",
         initializer=target_log_density(state.initialized_value()))
-    log_accept_ratio = variable_scope.get_variable(
-        'log_accept_ratio', initializer=0.)
+    log_accept_ratio = variable_scope.get_variable(
+        "log_accept_ratio", initializer=0.)
+    # Random walk MCMC converges slowly so need to put in enough iterations.
     num_iterations = 5000
     step = mh.evolve(state, state_log_density, log_accept_ratio,
@@ -98,11 +259,11 @@ class McmcStepTest(test.TestCase):
     self.assertAlmostEqual(sample_mean, 2.0, delta=0.1)
     self.assertAlmostEqual(sample_variance, 1.0, delta=0.1)

-  def test_normal_proposals(self):
+  def testProposalNormal(self):
     """Tests that the normal proposals are correctly distributed."""
     initial_points = array_ops.ones([10000], dtype=dtypes.float32)

-    proposal_fn = mh.normal_random_proposal(
+    proposal_fn = mh.proposal_normal(
         scale=2.0, seed=1234)

     proposal_points, _ = proposal_fn(initial_points)
@@ -115,7 +276,7 @@ class McmcStepTest(test.TestCase):
     self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1)
     self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1)

-  def test_docstring_example(self):
+  def testDocstringExample(self):
     """Tests the simplified docstring example with multiple chains."""
     n = 2  # dimension of the problem
@@ -123,7 +284,7 @@ class McmcStepTest(test.TestCase):
     # Generate 300 initial values randomly. Each of these would be an
     # independent starting point for a Markov chain.
     state = variable_scope.get_variable(
-        'state', initializer=random_ops.random_normal(
+        "state", initializer=random_ops.random_normal(
             [300, n], mean=3.0, dtype=dtypes.float32, seed=42))

     # Computes the log(p(x)) for the unit normal density and ignores the
@@ -133,12 +294,12 @@ class McmcStepTest(test.TestCase):
     # Initial log-density value
     state_log_density = variable_scope.get_variable(
-        'state_log_density',
+        "state_log_density",
         initializer=log_density(state.initialized_value()))

     # A variable to store the log_acceptance_ratio:
     log_acceptance_ratio = variable_scope.get_variable(
-        'log_acceptance_ratio',
+        "log_acceptance_ratio",
         initializer=array_ops.zeros([300], dtype=dtypes.float32))

     # Generates random proposals by moving each coordinate uniformly and
@@ -175,5 +336,5 @@ class McmcStepTest(test.TestCase):
                     - np.reshape(covariance, [n**2]))), 0, delta=0.2)

-if __name__ == '__main__':
+if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
index 9e45c19411..82693c2b7b 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
@@ -46,15 +46,13 @@ __all__ = [
 KernelResults = collections.namedtuple(
     "KernelResults",
     [
-        "acceptance_probs",
+        "log_accept_ratio",
         "current_grads_target_log_prob",  # "Current result" means "accepted".
         "current_target_log_prob",  # "Current result" means "accepted".
-        "energy_change",
         "is_accepted",
         "proposed_grads_target_log_prob",
         "proposed_state",
         "proposed_target_log_prob",
-        "random_positive",
     ])


@@ -63,15 +61,13 @@ def _make_dummy_kernel_results(
     dummy_target_log_prob,
     dummy_grads_target_log_prob):
   return KernelResults(
-      acceptance_probs=dummy_target_log_prob,
+      log_accept_ratio=dummy_target_log_prob,
       current_grads_target_log_prob=dummy_grads_target_log_prob,
       current_target_log_prob=dummy_target_log_prob,
-      energy_change=dummy_target_log_prob,
       is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool),
       proposed_grads_target_log_prob=dummy_grads_target_log_prob,
       proposed_state=dummy_state,
       proposed_target_log_prob=dummy_target_log_prob,
-      random_positive=dummy_target_log_prob,
   )


@@ -244,7 +240,7 @@ def sample_chain(
       Default value: `None` (i.e., "hmc_sample_chain").
   Returns:
-    accepted_states: Tensor or Python list of `Tensor`s representing the
+    next_states: Tensor or Python list of `Tensor`s representing the
       state(s) of the Markov chain(s) at each result step. Has same shape as
       input `current_state` but with a prepended `num_results`-size dimension.
     kernel_results: `collections.namedtuple` of internal calculations used to
@@ -470,7 +466,7 @@ def sample_annealed_importance_chain(
       Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").

   Returns:
-    accepted_state: `Tensor` or Python list of `Tensor`s representing the
+    next_state: `Tensor` or Python list of `Tensor`s representing the
       state(s) of the Markov chain(s) at the final iteration. Has same shape as
       input `current_state`.
     ais_weights: Tensor with the estimated weight(s). Has shape matching
@@ -591,18 +587,19 @@ def kernel(target_log_prob_fn,

   target = tfd.Normal(loc=dtype(0), scale=dtype(1))

-  new_x, other_results = hmc.kernel(
+  next_x, other_results = hmc.kernel(
       target_log_prob_fn=target.log_prob,
       current_state=x,
       step_size=step_size,
       num_leapfrog_steps=3)[:4]

-  x_update = x.assign(new_x)
+  x_update = x.assign(next_x)

   step_size_update = step_size.assign_add(
       step_size * tf.where(
-          other_results.acceptance_probs > target_accept_rate,
-          0.01, -0.01))
+          tf.exp(tf.minimum(other_results.log_accept_ratio, 0.)) >
+              target_accept_rate,
+          0.01, -0.01))

   warmup = tf.group([x_update, step_size_update])
@@ -753,7 +750,7 @@ def kernel(target_log_prob_fn,
       Default value: `None` (i.e., "hmc_kernel").

   Returns:
-    accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
+    next_state: Tensor or Python list of `Tensor`s representing the state(s)
       of the Markov chain(s) at each result step. Has same shape as
       `current_state`.
     kernel_results: `collections.namedtuple` of internal calculations used to
@@ -806,30 +803,27 @@ def kernel(target_log_prob_fn,
           proposed_target_log_prob,
           proposed_momentums,
           independent_chain_ndims)
+      log_accept_ratio = -energy_change

-      # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
-      # ==> -log(u) >= max(e, 0)
-      # ==> -log(u) >= e
-      # (Perhaps surprisingly, we don't have a better way to obtain a random
-      # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
-      # maxval=np.inf)` won't work.)
-      random_uniform = random_ops.random_uniform(
+      # u < exp(log_accept_ratio),  where u~Uniform[0,1)
+      # ==> log(u) < log_accept_ratio
+      random_value = random_ops.random_uniform(
           shape=array_ops.shape(energy_change),
           dtype=energy_change.dtype,
           seed=seed)
-      random_positive = -math_ops.log(random_uniform)
-      is_accepted = random_positive >= energy_change
+      random_negative = math_ops.log(random_value)
+      is_accepted = random_negative < log_accept_ratio

       accepted_target_log_prob = array_ops.where(is_accepted,
                                                  proposed_target_log_prob,
                                                  current_target_log_prob)

-      accepted_state_parts = [_choose(is_accepted,
-                                      proposed_state_part,
-                                      current_state_part,
-                                      independent_chain_ndims)
-                              for current_state_part, proposed_state_part
-                              in zip(current_state_parts, proposed_state_parts)]
+      next_state_parts = [_choose(is_accepted,
+                                  proposed_state_part,
+                                  current_state_part,
+                                  independent_chain_ndims)
+                          for current_state_part, proposed_state_part
+                          in zip(current_state_parts, proposed_state_parts)]

       accepted_grads_target_log_prob = [
           _choose(is_accepted,
@@ -841,17 +835,15 @@ def kernel(target_log_prob_fn,
       maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
       return [
-          maybe_flatten(accepted_state_parts),
+          maybe_flatten(next_state_parts),
           KernelResults(
-              acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change,
-                                                             0.)),
+              log_accept_ratio=log_accept_ratio,
               current_grads_target_log_prob=accepted_grads_target_log_prob,
               current_target_log_prob=accepted_target_log_prob,
-              energy_change=energy_change,
               is_accepted=is_accepted,
               proposed_grads_target_log_prob=proposed_grads_target_log_prob,
               proposed_state=maybe_flatten(proposed_state_parts),
               proposed_target_log_prob=proposed_target_log_prob,
-              random_positive=random_positive,
           ),
       ]


@@ -883,8 +875,8 @@ def _leapfrog_integrator(current_momentums,
   momentum = tf.placeholder(np.float32)

   [
-    new_momentums,
-    new_positions,
+    next_momentums,
+    next_positions,
   ] = hmc._leapfrog_integrator(
       current_momentums=[momentum],
       target_log_prob_fn=tfd.MultivariateNormalDiag(
@@ -901,7 +893,7 @@ def _leapfrog_integrator(current_momentums,
   positions = np.zeros([num_iter, dims], dtype)
   for i in xrange(num_iter):
     position_, momentum_ = sess.run(
-        [new_momentums[0], new_position[0]],
+        [next_momentums[0], next_position[0]],
         feed_dict={position: position_, momentum: momentum_})
     positions[i] = position_
@@ -944,9 +936,9 @@ def _leapfrog_integrator(current_momentums,
       state(s) of the Markov chain(s) at each result step. Has same shape as
       input `current_state_parts`.
     proposed_target_log_prob: `Tensor` representing the value of
-      `target_log_prob_fn` at `accepted_state`.
+      `target_log_prob_fn` at `next_state`.
     proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
-      `accepted_state`.
+      `next_state`.

   Raises:
     ValueError: if `len(momentums) != len(state_parts)`.
@@ -1066,8 +1058,8 @@ def _compute_energy_change(current_target_log_prob,
                                  axis=-1)
   lk1 = -np.log(2.) + math_ops.reduce_logsumexp(
       array_ops.stack(lk1, axis=-1), axis=-1)

-  lp0 = -current_target_log_prob   # log_potential
-  lp1 = -proposed_target_log_prob  # proposed_log_potential
+  lp0 = -current_target_log_prob   # potential
+  lp1 = -proposed_target_log_prob  # proposed_potential
   x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)],
                       axis=-1)
diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py
index 7bdeaa862d..e7fcbc65ef 100644
--- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py
+++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py
@@ -25,9 +25,10 @@ from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import *
 from tensorflow.python.util.all_util import remove_undocumented

 _allowed_symbols = [
+    'kernel',
     'evolve',
-    'uniform_random_proposal',
-    'normal_random_proposal',
+    'proposal_uniform',
+    'proposal_normal',
 ]

 remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py
index dc1ac68ce0..05aa134ed5 100644
--- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py
@@ -12,17 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Functions to create a Markov Chain Monte Carlo Metropolis step.
+"""Metropolis-Hastings and proposal distributions.

+@@kernel
 @@evolve
-@@uniform_random_proposal
-@@normal_random_proposal
+@@proposal_uniform
+@@proposal_normal
 """

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -31,123 +34,198 @@ from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import state_ops

 __all__ = [
-    'evolve',
-    'uniform_random_proposal',
-    'normal_random_proposal',
+    "kernel",
+    "evolve",
+    "proposal_uniform",
+    "proposal_normal",
 ]


-def _single_iteration(current_state, current_log_density,
-                      log_unnormalized_prob_fn, proposal_fn, seed=None,
-                      name='None'):
-  """Performs a single Metropolis-Hastings step.
+KernelResults = collections.namedtuple(
+    "KernelResults",
+    [
+        "log_accept_ratio",
+        "current_target_log_prob",  # "Current result" means "accepted".
+        "is_accepted",
+        "proposed_state",
+    ])
+
+
+def kernel(target_log_prob_fn,
+           proposal_fn,
+           current_state,
+           seed=None,
+           current_target_log_prob=None,
+           name=None):
+  """Runs the Metropolis-Hastings transition kernel.
+
+  This function can update multiple chains in parallel. It assumes that all
+  leftmost dimensions of `current_state` index independent chain states (and are
+  therefore updated independently). The output of `target_log_prob_fn()` should
+  sum log-probabilities across all event dimensions. Slices along the rightmost
+  dimensions may have different target distributions; for example,
+  `current_state[0, :]` could have a different target distribution from
+  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
+  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
   Args:
-    current_state: Float-like `Tensor` (i.e., `dtype` is either
-      `tf.float16`, `tf.float32` or `tf.float64`) of any shape that can
-      be consumed by the `log_unnormalized_prob_fn` and `proposal_fn`
-      callables.
-    current_log_density: Float-like `Tensor` with `dtype` and shape equivalent
-      to `log_unnormalized_prob_fn(current_state)`, i.e., matching the result of
-      `log_unnormalized_prob_fn` invoked at `current_state`.
-    log_unnormalized_prob_fn: A Python callable evaluated at
-      `current_state` and returning a float-like `Tensor` of log target-density
-      up to a normalizing constant. In other words,
-      `log_unnormalized_prob_fn(x) = log(g(x))`, where
-      `target_density = g(x)/Z` for some constant `A`. The shape of the input
-      tensor is the same as the shape of the `current_state`. The shape of the
-      output tensor is either
-        (a). Same as the input shape if the density being sampled is one
-          dimensional, or
-        (b). If the density is defined for `events` of shape
-          `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of
-          shape `batch_shape + event_shape`, where `batch_shape = [B1, ..., Bb]`
-          and the result must be of shape [B1, ..., Bb]. For example, if the
-          distribution that is being sampled is a 10 dimensional normal,
-          then the input tensor may be of shape [100, 10] or [30, 20, 10]. The
-          last dimension will then be 'consumed' by `log_unnormalized_prob_fn`
-          and it should return tensors of shape [100] and [30, 20] respectively.
-    proposal_fn: A callable accepting a real valued `Tensor` of current sample
-      points and returning a tuple of two `Tensors`. The first element of the
-      pair is a `Tensor` containing the proposal state and should have
-      the same shape as the input `Tensor`. The second element of the pair gives
-      the log of the ratio of the probability of transitioning from the
-      proposal points to the input points and the probability of transitioning
-      from the input points to the proposal points. If the proposal is
-      symmetric (e.g., random walk, where the proposal is either
-      normal or uniform centered at `current_state`), i.e.,
-      Probability(Proposal -> Current) = Probability(Current -> Proposal)
-      the second value should be set to `None` instead of explicitly supplying a
-      tensor of zeros. In addition to being convenient, this also leads to a
-      more efficient graph.
-    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
-      applied.
-    name: Python `str` name prefix for ops managed by this function.
+    target_log_prob_fn: Python callable which takes an argument like
+      `current_state` (or `*current_state` if it's a list) and returns its
+      (possibly unnormalized) log-density under the target distribution.
+    proposal_fn: Python callable which takes an argument like `current_state`
+      (or `*current_state` if it's a list) and returns a tuple of proposed
+      states of same shape as `state`, and a log ratio `Tensor` of same shape
+      as `current_target_log_prob`. The log ratio is the log-probability of
+      `state` given proposed states minus the log-probability of proposed
+      states given `state`. If the proposal is symmetric, set the second value
+      to `None`: this enables more efficient computation than explicitly
+      supplying a tensor of zeros.
+    current_state: `Tensor` or Python `list` of `Tensor`s representing the
+      current state(s) of the Markov chain(s). The first `r` dimensions index
+      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
+    seed: Python integer to seed the random number generator.
+    current_target_log_prob: (Optional) `Tensor` representing the value of
+      `target_log_prob_fn` at the `current_state`. The only reason to
+      specify this argument is to reduce TF graph size.
+      Default value: `None` (i.e., compute as needed).
+    name: A name of the operation (optional).

   Returns:
-    next_state: `Tensor` with `dtype` and shape matching `current_state`.
-      Created by propagating the chain by one step, starting from
+    next_state: Tensor or Python list of `Tensor`s representing the state(s)
+      of the Markov chain(s) at each result step. Has same shape as
       `current_state`.
-    next_log_density: `Tensor` with `dtype` and shape matching
-      `current_log_density`, which is equal to the value of the unnormalized
-      `log_unnormalized_prob_fn` computed at `next_state`.
-    log_accept_ratio: `Tensor` with `dtype` and shape matching
-      `current_log_density`. Stands for the log of Metropolis-Hastings
-      acceptance ratio used in generating the `next_state`.
-  """
+    kernel_results: `collections.namedtuple` of internal calculations used to
+      advance the chain.

-  with ops.name_scope(name, 'single_iteration', [current_state]):
-    # The proposed state and the log of the corresponding Hastings ratio.
-    proposal_state, log_transit_ratio = proposal_fn(current_state)
-
-    # If the log ratio is None, assume that the transitions are symmetric,
-    # i.e., Prob(Current -> Proposed) = Prob(Proposed -> Current).
-    if log_transit_ratio is None:
-      log_transit_ratio = 0.
-
-    # Log-density of the proposal state.
-    proposal_log_density = log_unnormalized_prob_fn(proposal_state)
-
-    # Ops to compute the log of the acceptance ratio. Recall that the
-    # acceptance ratio is: [Prob(Proposed) / Prob(Current)] *
-    # [Prob(Proposed -> Current) / Prob(Current -> Proposed)]. The log of the
-    # second term is the log_transit_ratio.
-    with ops.name_scope('accept_reject'):
-      # The log of the acceptance ratio.
-      log_accept_ratio = (proposal_log_density - current_log_density
-                          + log_transit_ratio)
-
-      # A proposal is accepted or rejected depending on the acceptance ratio.
-      # If the acceptance ratio is greater than 1 then it is always accepted.
-      # If the acceptance ratio is less than 1 then the proposal is accepted
-      # with probability = acceptance ratio. As we are working in log space to
-      # prevent over/underflows, this logic is expressed in log terms below.
-      # If a proposal is accepted we place a True in the acceptance state
-      # tensor and if it is to be rejected we place a False.
-      # The log_draws below have to be compared to the log_accept_ratio so we
-      # make sure that they have the same data type.
-      log_draws = math_ops.log(random_ops.random_uniform(
-          array_ops.shape(current_log_density), seed=seed,
-          dtype=log_accept_ratio.dtype))
-      is_proposal_accepted = log_draws < log_accept_ratio
-
-    # The acceptance state decides which elements of the current state are to
-    # be replaced with the corresponding elements in the proposal state.
-    with ops.name_scope(name, 'metropolis_single_step',
-                        [current_state, current_log_density]):
-      next_log_density = array_ops.where(is_proposal_accepted,
-                                         proposal_log_density,
-                                         current_log_density)
-      next_state = array_ops.where(is_proposal_accepted, proposal_state,
-                                   current_state)
-
-  return next_state, next_log_density, log_accept_ratio
+  #### Examples
+
+  We illustrate Metropolis-Hastings on a Normal likelihood with
+  unknown mean.
+
+  ```python
+  tfd = tf.contrib.distributions
+  tfp = tf.contrib.bayesflow
+
+  loc = tf.get_variable("loc", initializer=1.)
+  x = tf.constant([0.0] * 50)
+
+  def make_target_log_prob_fn(x):
+    def target_log_prob_fn(loc):
+      prior = tfd.Normal(loc=0., scale=1.)
+      likelihood = tfd.Independent(
+          tfd.Normal(loc=loc, scale=0.1),
+          reinterpreted_batch_ndims=1)
+      return prior.log_prob(loc) + likelihood.log_prob(x)
+    return target_log_prob_fn
+
+  next_state, kernel_results = tfp.metropolis_hastings.kernel(
+      target_log_prob_fn=make_target_log_prob_fn(x),
+      proposal_fn=tfp.metropolis_hastings.proposal_normal(),
+      current_state=loc)
+  loc_update = loc.assign(next_state)
+  ```
+
+  We illustrate Metropolis-Hastings on a Normal likelihood with
+  unknown mean and variance. We apply 4 chains.
+
+  ```python
+  tfd = tf.contrib.distributions
+  tfp = tf.contrib.bayesflow
+
+  num_chains = 4
+  loc = tf.get_variable("loc", shape=[num_chains],
+                        initializer=tf.random_normal_initializer())
+  scale = tf.get_variable("scale", shape=[num_chains],
+                          initializer=tf.ones_initializer())
+  x = tf.constant([0.0] * 50)
+
+  def make_target_log_prob_fn(x):
+    data = tf.reshape(x, shape=[-1, 1])
+    def target_log_prob_fn(loc, scale):
+      prior_loc = tfd.Normal(loc=0., scale=1.)
+      prior_scale = tfd.InverseGamma(concentration=1., rate=1.)
+      likelihood = tfd.Independent(
+          tfd.Normal(loc=loc, scale=scale),
+          reinterpreted_batch_ndims=1)
+      return (prior_loc.log_prob(loc) +
+              prior_scale.log_prob(scale) +
+              likelihood.log_prob(data))
+    return target_log_prob_fn
+
+  def proposal_fn(loc, scale):
+    loc_proposal = tfp.metropolis_hastings.proposal_normal()
+    scale_proposal = tfp.metropolis_hastings.proposal_uniform(minval=-1.)
+    proposed_loc, _ = loc_proposal(loc)
+    proposed_scale, _ = scale_proposal(scale)
+    proposed_scale = tf.maximum(proposed_scale, 0.01)
+    return [proposed_loc, proposed_scale], None
+
+  next_state, kernel_results = tfp.metropolis_hastings.kernel(
+      target_log_prob_fn=make_target_log_prob_fn(x),
+      proposal_fn=proposal_fn,
+      current_state=[loc, scale])
+  train_op = tf.group(loc.assign(next_state[0]),
+                      scale.assign(next_state[1]))
+  ```
+  """
+  with ops.name_scope(
+      name, "metropolis_hastings_kernel",
+      [current_state, seed, current_target_log_prob]):
+    with ops.name_scope("initialize"):
+      maybe_expand = lambda x: list(x) if _is_list_like(x) else [x]
+      current_state_parts = maybe_expand(current_state)
+      if current_target_log_prob is None:
+        current_target_log_prob = target_log_prob_fn(*current_state_parts)
+
+    proposed_state, log_transit_ratio = proposal_fn(*current_state_parts)
+    proposed_state_parts = maybe_expand(proposed_state)
+
+    proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts)
+
+    with ops.name_scope(
+        "accept_reject",
+        [current_state_parts, proposed_state_parts,
+         current_target_log_prob, proposed_target_log_prob]):
+      log_accept_ratio = proposed_target_log_prob - current_target_log_prob
+      if log_transit_ratio is not None:
+        # If the log_transit_ratio is None, then assume the proposal is
+        # symmetric, i.e.,
+        # log p(old | new) - log p(new | old) = 0.
+        log_accept_ratio += log_transit_ratio
+
+      # u < exp(log_accept_ratio),  where u~Uniform[0,1)
+      # ==> log(u) < log_accept_ratio
+      random_value = random_ops.random_uniform(
+          array_ops.shape(log_accept_ratio),
+          dtype=log_accept_ratio.dtype,
+          seed=seed)
+      random_negative = math_ops.log(random_value)
+      is_accepted = random_negative < log_accept_ratio
+      next_state_parts = [array_ops.where(is_accepted,
+                                          proposed_state_part,
+                                          current_state_part)
+                          for proposed_state_part, current_state_part in
+                          zip(proposed_state_parts, current_state_parts)]
+      accepted_log_prob = array_ops.where(is_accepted,
+                                          proposed_target_log_prob,
+                                          current_target_log_prob)
+      maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
+      return [
+          maybe_flatten(next_state_parts),
+          KernelResults(
+              log_accept_ratio=log_accept_ratio,
+              current_target_log_prob=accepted_log_prob,
+              is_accepted=is_accepted,
+              proposed_state=maybe_flatten(proposed_state_parts),
+          ),
+      ]


 def evolve(initial_sample,
            initial_log_density,
            initial_log_accept_ratio,
-           log_unnormalized_prob_fn,
+           target_log_prob_fn,
            proposal_fn,
            n_steps=1,
            seed=None,
@@ -162,9 +240,11 @@ def evolve(initial_sample,
   The probability distribution may have an unknown normalization constan.
   We parameterize the probability density as follows:
-  ```
-    f(x) = exp(L(x) + constant)
-  ```
+
+  ```none
+  f(x) = exp(L(x) + constant)
+  ```
+
   Here `L(x)` is any continuous function with an (possibly unknown but finite)
   upper bound, i.e. there exists a number beta such that
   `L(x)< beta < infinity` for all x. The constant is the normalization needed
@@ -188,72 +268,77 @@ def evolve(initial_sample,
   The following example, demonstrates the use to generate a 1000 uniform
   random walk Metropolis samplers run in parallel for the normal target
   distribution.
+
   ```python
-  n = 3  # dimension of the problem
-
-  # Generate 1000 initial values randomly. Each of these would be an
-  # independent starting point for a Markov chain.
-  state = tf.get_variable(
-      'state',initializer=tf.random_normal([1000, n], mean=3.0,
-                                           dtype=tf.float64, seed=42))
-
-  # Computes the log(p(x)) for the unit normal density and ignores the
-  # normalization constant.
-  def log_density(x):
-    return - tf.reduce_sum(x * x, reduction_indices=-1) / 2.0
-
-  # Initial log-density value
-  state_log_density = tf.get_variable(
-      'state_log_density', initializer=log_density(state.initialized_value()))
-
-  # A variable to store the log_acceptance_ratio:
-  log_acceptance_ratio = tf.get_variable(
-      'log_acceptance_ratio', initializer=tf.zeros([1000], dtype=tf.float64))
-
-  # Generates random proposals by moving each coordinate uniformly and
-  # independently in a box of size 2 centered around the current value.
-  # Returns the new point and also the log of the Hastings ratio (the
-  # ratio of the probability of going from the proposal to origin and the
-  # probability of the reverse transition). When this ratio is 1, the value
-  # may be omitted and replaced by None.
-  def random_proposal(x):
-    return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1,
-                                  dtype=x.dtype, seed=12)), None
-
-  # Create the op to propagate the chain for 100 steps.
-  stepper = mh.evolve(
-      state, state_log_density, log_acceptance_ratio,
-      log_density, random_proposal, n_steps=100, seed=123)
-  init = tf.initialize_all_variables()
-  with tf.Session() as sess:
-    sess.run(init)
-    # Run the chains for a total of 1000 steps and print out the mean across
-    # the chains every 100 iterations.
-    for n_iter in range(10):
-      # Executing the stepper advances the chain to the next state.
-      sess.run(stepper)
-      # Print out the current value of the mean(sample) for every dimension.
-      print(np.mean(sess.run(state), 0))
-    # Estimated covariance matrix
-    samples = sess.run(state)
-    print('')
-    print(np.cov(samples, rowvar=False))
+  n = 3  # dimension of the problem
+
+  # Generate 1000 initial values randomly. Each of these would be an
+  # independent starting point for a Markov chain.
+  state = tf.get_variable(
+      "state",
+      initializer=tf.random_normal([1000, n],
+                                   mean=3.0,
+                                   dtype=tf.float64,
+                                   seed=42))
+
+  # Computes the log(p(x)) for the unit normal density and ignores the
+  # normalization constant.
+  def log_density(x):
+    return -tf.reduce_sum(x * x, reduction_indices=-1) / 2.0
+
+  # Initial log-density value
+  state_log_density = tf.get_variable(
+      "state_log_density",
+      initializer=log_density(state.initialized_value()))
+
+  # A variable to store the log_acceptance_ratio:
+  log_acceptance_ratio = tf.get_variable(
+      "log_acceptance_ratio",
+      initializer=tf.zeros([1000], dtype=tf.float64))
+
+  # Generates random proposals by moving each coordinate uniformly and
+  # independently in a box of size 2 centered around the current value.
+  # Returns the new point and also the log of the Hastings ratio (the
+  # ratio of the probability of going from the proposal to origin and the
+  # probability of the reverse transition). When this ratio is 1, the value
+  # may be omitted and replaced by None.
+  def random_proposal(x):
+    return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1,
+                                  dtype=x.dtype, seed=12)), None
+
+  # Create the op to propagate the chain for 100 steps.
+  stepper = mh.evolve(
+      state, state_log_density, log_acceptance_ratio,
+      log_density, random_proposal, n_steps=100, seed=123)
+  init = tf.initialize_all_variables()
+  with tf.Session() as sess:
+    sess.run(init)
+    # Run the chains for a total of 1000 steps and print out the mean across
+    # the chains every 100 iterations.
+    for n_iter in range(10):
+      # Executing the stepper advances the chain to the next state.
+      sess.run(stepper)
+      # Print out the current value of the mean(sample) for every dimension.
+      print(np.mean(sess.run(state), 0))
+    # Estimated covariance matrix
+    samples = sess.run(state)
+    print(np.cov(samples, rowvar=False))
   ```

   Args:
     initial_sample: A float-like `tf.Variable` of any shape that can
-      be consumed by the `log_unnormalized_prob_fn` and `proposal_fn`
+      be consumed by the `target_log_prob_fn` and `proposal_fn`
       callables.
     initial_log_density: Float-like `tf.Variable` with `dtype` and shape
-      equivalent to `log_unnormalized_prob_fn(initial_sample)`, i.e., matching
-      the result of `log_unnormalized_prob_fn` invoked at `current_state`.
+      equivalent to `target_log_prob_fn(initial_sample)`, i.e., matching
+      the result of `target_log_prob_fn` invoked at `current_state`.
     initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching
       `initial_log_density`. Stands for the log of Metropolis-Hastings
       acceptance ratio after propagating the chain for `n_steps`.
-    log_unnormalized_prob_fn: A Python callable evaluated at
+    target_log_prob_fn: A Python callable evaluated at
       `current_state` and returning a float-like `Tensor` of log target-density
       up to a normalizing constant. In other words,
-      `log_unnormalized_prob_fn(x) = log(g(x))`, where
+      `target_log_prob_fn(x) = log(g(x))`, where
       `target_density = g(x)/Z` for some constant `A`.
       The shape of the input
       tensor is the same as the shape of the `current_state`. The shape of the
       output tensor is either
@@ -265,7 +350,7 @@ def evolve(initial_sample,
           and the result must be of shape [B1, ..., Bb]. For example, if the
           distribution that is being sampled is a 10 dimensional normal,
           then the input tensor may be of shape [100, 10] or [30, 20, 10]. The
-          last dimension will then be 'consumed' by `log_unnormalized_prob_fn`
+          last dimension will then be 'consumed' by `target_log_prob_fn`
           and it should return tensors of shape [100] and [30, 20] respectively.
     proposal_fn: A callable accepting a real valued `Tensor` of current sample
       points and returning a tuple of two `Tensors`. The first element of the
@@ -289,42 +374,48 @@ def evolve(initial_sample,
     forward_step: an `Op` to step the Markov chain forward for `n_steps`.
   """
-  with ops.name_scope(name, 'metropolis_hastings', [initial_sample]):
+  with ops.name_scope(name, "metropolis_hastings", [initial_sample]):
     current_state = initial_sample
-    current_log_density = initial_log_density
+    current_target_log_prob = initial_log_density
     log_accept_ratio = initial_log_accept_ratio

-    # Stop condition for the while_loop
-    def stop_condition(i, _):
-      return i < n_steps
-
-    def step(i, loop_vars):
-      """Wrap `_single_iteration` for `while_loop`."""
-      state = loop_vars[0]
-      state_log_density = loop_vars[1]
-      return i + 1, list(_single_iteration(state, state_log_density,
-                                           log_unnormalized_prob_fn,
-                                           proposal_fn, seed=seed))
-
-    loop_vars = [current_state, current_log_density, log_accept_ratio]
-    # Build an `Op` to evolve the Markov chain for `n_steps`
-    (_, [end_state, end_log_density, end_log_acceptance]) = (
+    def step(i, current_state, current_target_log_prob, log_accept_ratio):
+      """Wrap single Markov chain iteration in `while_loop`."""
+      next_state, kernel_results = kernel(
+          target_log_prob_fn=target_log_prob_fn,
+          proposal_fn=proposal_fn,
+          current_state=current_state,
+          current_target_log_prob=current_target_log_prob,
+          seed=seed)
+      accepted_log_prob = kernel_results.current_target_log_prob
+      log_accept_ratio = kernel_results.log_accept_ratio
+      return i + 1, next_state, accepted_log_prob, log_accept_ratio
+
+    (_, accepted_state, accepted_target_log_prob, accepted_log_accept_ratio) = (
         control_flow_ops.while_loop(
-            stop_condition, step,
-            (0, loop_vars),
-            parallel_iterations=1, swap_memory=1))
+            cond=lambda i, *ignored_args: i < n_steps,
+            body=step,
+            loop_vars=[
+                0,  # i
+                current_state,
+                current_target_log_prob,
+                log_accept_ratio,
+            ],
+            parallel_iterations=1 if seed is not None else 10,
+            # TODO(b/73775595): Confirm optimal setting of swap_memory.
+            swap_memory=1))

     forward_step = control_flow_ops.group(
-        state_ops.assign(current_log_density, end_log_density),
-        state_ops.assign(current_state, end_state),
-        state_ops.assign(log_accept_ratio, end_log_acceptance))
+        state_ops.assign(current_target_log_prob, accepted_target_log_prob),
+        state_ops.assign(current_state, accepted_state),
+        state_ops.assign(log_accept_ratio, accepted_log_accept_ratio))

     return forward_step


-def uniform_random_proposal(step_size=1.,
-                            seed=None,
-                            name=None):
+def proposal_uniform(step_size=1.,
+                     seed=None,
+                     name=None):
   """Returns a callable that adds a random uniform tensor to the input.

   This function returns a callable that accepts one `Tensor` argument of any
@@ -346,11 +437,13 @@ def uniform_random_proposal(step_size=1.,

   Returns:
     proposal_fn: A callable accepting one float-like `Tensor` and returning a
-      2-tuple. The first value in the tuple is a `Tensor` of the same shape and
-      dtype as the input argument and the second element of the tuple is None.
+      2-tuple. The first value in the tuple is a `Tensor` of the same shape and
+      dtype as the input argument and the second element of the tuple is None.
   """
-  with ops.name_scope(name, 'uniform_random_proposal', [step_size]):
+  with ops.name_scope(name, "proposal_uniform", [step_size]):
+    step_size = ops.convert_to_tensor(step_size, name="step_size")
+
     def proposal_fn(input_state, name=None):
       """Adds a uniform perturbation to the input state.
@@ -359,12 +452,12 @@ def uniform_random_proposal(step_size=1.,
         name: A string that sets the name for this `Op`.

       Returns:
-        proposal_state: A float-like `Tensot` with `dtype` and shape matching
+        proposal_state: A float-like `Tensor` with `dtype` and shape matching
           `input_state`.
         log_transit_ratio: `None`. Proposal is symmetric.
       """
-      with ops.name_scope(name, 'proposer', [input_state]):
-        input_state = ops.convert_to_tensor(input_state, name='input_state')
+      with ops.name_scope(name, "proposer", [input_state]):
+        input_state = ops.convert_to_tensor(input_state, name="input_state")
         return input_state + random_ops.random_uniform(
             array_ops.shape(input_state),
             minval=-step_size,
@@ -373,9 +466,9 @@ def uniform_random_proposal(step_size=1.,
   return proposal_fn


-def normal_random_proposal(scale=1.,
-                           seed=None,
-                           name=None):
+def proposal_normal(scale=1.,
+                    seed=None,
+                    name=None):
   """Returns a callable that adds a random normal tensor to the input.

   This function returns a callable that accepts one `Tensor` argument of any
@@ -398,11 +491,13 @@ def normal_random_proposal(scale=1.,

   Returns:
     proposal_fn: A callable accepting one float-like `Tensor` and returning a
-      2-tuple. The first value in the tuple is a `Tensor` of the same shape and
-      dtype as the input argument and the second element of the tuple is None.
+      2-tuple. The first value in the tuple is a `Tensor` of the same shape and
+      dtype as the input argument and the second element of the tuple is None.
   """
-  with ops.name_scope(name, 'normal_random_proposal', [scale]):
+  with ops.name_scope(name, "proposal_normal", [scale]):
+    scale = ops.convert_to_tensor(scale, name="scale")
+
     def proposal_fn(input_state, name=None):
       """Adds a normal perturbation to the input state.
@@ -411,16 +506,22 @@ def normal_random_proposal(scale=1.,
         name: A string that sets the name for this `Op`.

       Returns:
-        proposal_state: A float-like `Tensot` with `dtype` and shape matching
+        proposal_state: A float-like `Tensor` with `dtype` and shape matching
          `input_state`.
         log_transit_ratio: `None`. Proposal is symmetric.
       """
-      with ops.name_scope(name, 'proposer', [input_state]):
-        input_state = ops.convert_to_tensor(input_state, name='input_state')
+      with ops.name_scope(name, "proposer", [input_state]):
+        input_state = ops.convert_to_tensor(input_state, name="input_state")
         return input_state + random_ops.random_normal(
             array_ops.shape(input_state),
             mean=0.,
             stddev=scale,
+            dtype=scale.dtype,
             seed=seed), None

   return proposal_fn
+
+
+def _is_list_like(x):
+  """Helper which returns `True` if input is `list`-like."""
+  return isinstance(x, (tuple, list))
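A closing note on the accept/reject arithmetic both kernels now share: a proposal is accepted when `u < exp(log_accept_ratio)` for `u ~ Uniform[0, 1)`, which the new code evaluates in log space as `log(u) < log_accept_ratio`. A NumPy sketch of that decision rule (the function name is illustrative, not part of the API):

```python
import numpy as np

def is_accepted(log_accept_ratio, seed=None):
  # u < exp(log_accept_ratio), u ~ Uniform[0, 1)
  # <=> log(u) < log_accept_ratio; staying in log space avoids overflow
  # for large ratios and keeps -inf ratios (forbidden states) rejected.
  rng = np.random.RandomState(seed)
  u = rng.uniform(size=np.shape(log_accept_ratio))
  return np.log(u) < log_accept_ratio

print(is_accepted(np.array([0., -np.inf]), seed=42))  # [ True False]
```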