author | Ian Langmore <langmore@google.com> | 2018-02-08 12:15:32 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-08 12:18:33 -0800
commit | 5e10de83a4586c85bbbca313a878987f606fb00b
tree | 7bae66e8429bb38f5a72c7661aef7e6c65b765a8 /tensorflow/contrib/bayesflow
parent | 52d51aae52d978a66242a4a2f3342aab7e112443
Add effective_sample_size to tf.contrib.bayesflow.mcmc_diagnostics.
Also, start dealing with list args in a more regular manner.
PiperOrigin-RevId: 185032115
Diffstat (limited to 'tensorflow/contrib/bayesflow')
4 files changed, 393 insertions(+), 64 deletions(-)
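Before the diff itself, a quick sketch of the API being added may help. This is an editor's illustration, not code from the commit: it assumes the TF 1.x contrib layout (`tf.contrib.bayesflow.mcmc_diagnostics`, the path used in the docstrings below) and a throwaway `tf.Session`; the two inputs mirror the iid and correlation-length-10 chains constructed in the new kernel tests.

```python
import numpy as np
import tensorflow as tf

mcmc = tf.contrib.bayesflow.mcmc_diagnostics

# One iid chain, and one chain with correlation length 10 (each value
# repeated 10 times), as in the new kernel tests.
iid_ = np.random.randn(5000).astype(np.float32)
corr_ = np.repeat(np.random.randn(500), 10).astype(np.float32)

# List args are handled "in a regular manner": a list of states may be
# paired with a list of per-state truncation rules.
ess = mcmc.effective_sample_size([iid_, corr_], max_lags_threshold=[0., 0.])

with tf.Session() as sess:
  ess_iid, ess_corr = sess.run(ess)
  # Expect ess_iid near 5000 (uncorrelated), ess_corr near 500 (= 5000 / 10).
```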
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 34156c28fe..8c856bb0b7 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -144,6 +144,7 @@ cuda_py_test(
     additional_deps = [
         ":bayesflow_py",
         "//third_party/py/numpy",
+        "//tensorflow/python:spectral_ops_test_util",
         "//tensorflow/contrib/distributions:distributions_py",
         "//tensorflow/python/ops/distributions",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
index 7652b6a7ce..d68fc9081a 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
@@ -22,11 +22,165 @@ import numpy as np
 from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import spectral_ops_test_util
 from tensorflow.python.platform import test
 
 rng = np.random.RandomState(42)
 
 
+class _EffectiveSampleSizeTest(object):
+
+  @property
+  def use_static_shape(self):
+    raise NotImplementedError(
+        "Subclass failed to implement `use_static_shape`.")
+
+  def _check_versus_expected_effective_sample_size(self,
+                                                   x_,
+                                                   expected_ess,
+                                                   sess,
+                                                   atol=1e-2,
+                                                   rtol=1e-2,
+                                                   max_lags_threshold=None,
+                                                   max_lags=None):
+    x = array_ops.placeholder_with_default(
+        input=x_, shape=x_.shape if self.use_static_shape else None)
+    ess = mcmc_diagnostics.effective_sample_size(
+        x, max_lags_threshold=max_lags_threshold, max_lags=max_lags)
+    if self.use_static_shape:
+      self.assertAllEqual(x.shape[1:], ess.shape)
+
+    ess_ = sess.run(ess)
+
+    self.assertAllClose(
+        np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol)
+
+  def testIidRank1NormalHasFullEssMaxLags10(self):
+    # With a length-5000 iid normal sequence and max_lags=10, we should get a
+    # good estimate of ESS, and it should be close to the full sequence length
+    # of 5000.
+    # The choice max_lags=10 is a short cutoff, reasonable only because we
+    # know the correlation length should be zero right away.
+    with self.test_session() as sess:
+      with spectral_ops_test_util.fft_kernel_label_map():
+        self._check_versus_expected_effective_sample_size(
+            x_=rng.randn(5000).astype(np.float32),
+            expected_ess=5000,
+            sess=sess,
+            max_lags=10,
+            rtol=0.3)
+
+  def testIidRank2NormalHasFullEssMaxLags10(self):
+    # See the similar Rank1Normal test for reasoning.
+    with self.test_session() as sess:
+      with spectral_ops_test_util.fft_kernel_label_map():
+        self._check_versus_expected_effective_sample_size(
+            x_=rng.randn(5000, 2).astype(np.float32),
+            expected_ess=5000,
+            sess=sess,
+            max_lags=10,
+            rtol=0.3)
+
+  def testIidRank1NormalHasFullEssMaxLagThresholdZero(self):
+    # With a length-5000 iid normal sequence and max_lags_threshold=0, we
+    # should get a very sharp estimate of ESS, and it should be very close to
+    # the full sequence length of 5000.
+    # The choice max_lags_threshold=0 means we cut off as soon as the
+    # auto-correlation drops below zero. This should happen very quickly,
+    # because the theoretical auto-correlation is [1, 0, 0, ...].
+ with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000).astype(np.float32), + expected_ess=5000, + sess=sess, + max_lags_threshold=0., + rtol=0.1) + + def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): + # See similar test for Rank1Normal for reasoning. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000, 2).astype(np.float32), + expected_ess=5000, + sess=sess, + max_lags_threshold=0., + rtol=0.1) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on. + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + max_lags=50, + rtol=0.2) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( + self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on. + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + max_lags_threshold=0., + rtol=0.1) + + def testListArgs(self): + # x_ has correlation length 10 ==> ESS = N / 10 + # y_ has correlation length 1 ==> ESS = N + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + y_ = rng.randn(50000).astype(np.float32) + states = [x_, x_, y_, y_] + max_lags_threshold = [0., None, 0., None] + max_lags = [None, 5, None, 5] + + # See other tests for reasoning on tolerance. 
+    with self.test_session() as sess:
+      with spectral_ops_test_util.fft_kernel_label_map():
+        ess = mcmc_diagnostics.effective_sample_size(
+            states,
+            max_lags_threshold=max_lags_threshold,
+            max_lags=max_lags)
+        ess_ = sess.run(ess)
+    self.assertAllEqual(4, len(ess_))
+
+    self.assertAllClose(50000 // 10, ess_[0], rtol=0.3)
+    self.assertAllClose(50000 // 10, ess_[1], rtol=0.3)
+    self.assertAllClose(50000, ess_[2], rtol=0.1)
+    self.assertAllClose(50000, ess_[3], rtol=0.1)
+
+
+class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest):
+
+  @property
+  def use_static_shape(self):
+    return True
+
+
+class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest):
+
+  @property
+  def use_static_shape(self):
+    return False
+
+
 class _PotentialScaleReductionTest(object):
 
   @property
@@ -48,7 +202,7 @@ class _PotentialScaleReductionTest(object):
     state_1 = rng.randn(n_samples, 3, 4) + offset
 
     rhat = mcmc_diagnostics.potential_scale_reduction(
-        state=[state_0, state_1], independent_chain_ndims=1)
+        chains_states=[state_0, state_1], independent_chain_ndims=1)
     self.assertIsInstance(rhat, list)
 
     with self.test_session() as sess:
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
index 5f3e6ade70..f3a645eafc 100644
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
+++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import *
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
+    "effective_sample_size",
     "potential_scale_reduction",
 ]
 
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
index 3b6f92463e..bb8b915a9b 100644
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utilities for Markov Chain Monte Carlo (MCMC) sampling.
 
+@@effective_sample_size
 @@potential_scale_reduction
 """
 
@@ -21,20 +22,189 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.distributions.python.ops import sample_stats
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
 __all__ = [
+    "effective_sample_size",
     "potential_scale_reduction",
 ]
 
 
-def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
+def effective_sample_size(states,
+                          max_lags_threshold=None,
+                          max_lags=None,
+                          name=None):
+  """Estimate a lower bound on effective sample size for each independent chain.
+
+  Roughly speaking, the "effective sample size" (ESS) is the size of an iid
+  sample with the same variance as `states`.
+
+  More precisely, given a stationary sequence of possibly correlated random
+  variables `X_1, X_2, ..., X_N`, each identically distributed, ESS is the
+  number such that
+
+  ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.```
+
+  If the sequence is uncorrelated, `ESS = N`. In general, one should expect
+  `ESS <= N`, with more highly correlated sequences having smaller `ESS`.
+
+  #### Example of using ESS to estimate standard error.
+
+  ```
+  tfd = tf.contrib.distributions
+  tfb = tf.contrib.bayesflow
+
+  target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
+
+  # Get 1000 states from one chain.
+  states, _ = tfb.hmc.sample_chain(
+      num_results=1000,
+      target_log_prob_fn=target.log_prob,
+      current_state=tf.constant([0., 0.]),
+      step_size=0.05,
+      num_leapfrog_steps=20,
+      num_burnin_steps=200)
+  states.shape
+  ==> (1000, 2)
+
+  ess = tfb.mcmc_diagnostics.effective_sample_size(states)
+  ==> Shape (2,) Tensor
+
+  mean, variance = tf.nn.moments(states, axes=[0])
+  standard_error = tf.sqrt(variance / ess)
+  ```
+
+  Some math shows that, with `R_k` the auto-correlation sequence,
+  `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have
+
+  ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]```
+
+  This function estimates the above by first estimating the auto-correlation.
+  Since `R_k` must be estimated using only `N - k` samples, it becomes
+  progressively noisier for larger `k`. For this reason, the summation over
+  `R_k` should be truncated at some number `max_lags < N`. Since many MCMC
+  methods generate chains where `R_k > 0`, a reasonable criterion is to
+  truncate at the first index where the estimated auto-correlation becomes
+  negative.
+
+  Args:
+    states: `Tensor` or list of `Tensor` objects. Dimension zero should index
+      identically distributed states.
+    max_lags_threshold: `Tensor` or list of `Tensor` objects.
+      Must broadcast with `states`. The auto-correlation sequence is truncated
+      after the first appearance of a term less than `max_lags_threshold`. If
+      both `max_lags` and `max_lags_threshold` are `None`,
+      `max_lags_threshold` defaults to `0`.
+    max_lags: `Tensor` or list of `Tensor` objects. Must be `int`-like and
+      scalar valued. The auto-correlation sequence is truncated to this
+      length. May be provided only if `max_lags_threshold` is not.
+    name: `String` name to prepend to created ops.
+
+  Returns:
+    ess: `Tensor` or list of `Tensor` objects. The effective sample size of
+      each component of `states`. Shape will be `states.shape[1:]`.
+
+  Raises:
+    ValueError: If `max_lags` or `max_lags_threshold` is a list whose length
+      differs from that of `states`.
+  """
+  states_was_list = _is_list_like(states)
+
+  # Convert all args to lists.
+  if not states_was_list:
+    states = [states]
+
+  max_lags = _broadcast_maybelist_arg(states, max_lags, "max_lags")
+  max_lags_threshold = _broadcast_maybelist_arg(states, max_lags_threshold,
+                                                "max_lags_threshold")
+
+  # Process items, one at a time.
+  with ops.name_scope(name, "effective_sample_size"):
+    ess_list = [
+        _effective_sample_size_single_state(s, ml, mlt)
+        for (s, ml, mlt) in zip(states, max_lags, max_lags_threshold)
+    ]
+
+  if states_was_list:
+    return ess_list
+  return ess_list[0]
+
+
+def _effective_sample_size_single_state(states, max_lags, max_lags_threshold):
+  """ESS computation for one single Tensor argument."""
+  if max_lags is not None and max_lags_threshold is not None:
+    raise ValueError(
+        "Expected at most one of max_lags, max_lags_threshold to be provided. "
+        "Found: {}, {}".format(max_lags, max_lags_threshold))
+
+  if max_lags_threshold is None:
+    max_lags_threshold = 0.
+
+  with ops.name_scope(
+      "effective_sample_size_single_state",
+      values=[states, max_lags, max_lags_threshold]):
+
+    states = ops.convert_to_tensor(states, name="states")
+    dt = states.dtype
+
+    if max_lags is not None:
+      auto_corr = sample_stats.auto_correlation(
+          states, axis=0, max_lags=max_lags)
+    elif max_lags_threshold is not None:
+      max_lags_threshold = ops.convert_to_tensor(
+          max_lags_threshold, dtype=dt, name="max_lags_threshold")
+      auto_corr = sample_stats.auto_correlation(states, axis=0)
+      # Get a binary mask to zero out values of auto_corr below the threshold.
+      #   mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
+      #   mask[i, ...] = 0, otherwise.
+      # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0, ...].
+      # Building it step by step, assume
+      # auto_corr = [1, 0.5, 0.0, 0.3] and max_lags_threshold = 0.2.
+      # Step 1: mask = [False, False, True, False]
+      mask = auto_corr < max_lags_threshold
+      # Step 2: mask = [0, 0, 1, 0]
+      mask = math_ops.cast(mask, dtype=dt)
+      # Step 3: mask = [0, 0, 1, 1]
+      mask = math_ops.cumsum(mask, axis=0)
+      # Step 4: mask = [1, 1, 0, 0]
+      mask = math_ops.maximum(1. - mask, 0.)
+      auto_corr *= mask
+    else:
+      auto_corr = sample_stats.auto_correlation(states, axis=0)
+
+    # With R[k] := auto_corr[k, ...],
+    #   ESS = N / {1 + 2 * Sum_{k=1}^{N-1} (N - k) / N * R[k]}
+    #       = N / {-1 + 2 * Sum_{k=0}^{N-1} (N - k) / N * R[k]}  (since R[0] = 1)
+    #   approx N / {-1 + 2 * Sum_{k=0}^{M} (N - k) / N * R[k]},
+    # where M is the max_lags truncation point chosen above.
+
+    # Get the factor (N - k) / N, and give it shape [M, 1, ..., 1], with total
+    # ndims the same as auto_corr.
+    n = _axis_size(states, axis=0)
+    k = math_ops.range(0., _axis_size(auto_corr, axis=0))
+    nk_factor = (n - k) / n
+    if auto_corr.shape.ndims is not None:
+      new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1)
+    else:
+      new_shape = array_ops.concat(
+          ([-1],
+           array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)),
+          axis=0)
+    nk_factor = array_ops.reshape(nk_factor, new_shape)
+
+    return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0))
+
+
+def potential_scale_reduction(chains_states,
+                              independent_chain_ndims=1,
+                              name=None):
   """Gelman and Rubin's potential scale reduction factor for chain convergence.
 
-  Given `N > 1` samples from each of `C > 1` independent chains, the potential
+  Given `N > 1` states from each of `C > 1` independent chains, the potential
   scale reduction factor, commonly referred to as R-hat, measures convergence
   of the chains (to the same target) by testing for equality of means.
   Specifically, R-hat measures the degree to which variance (of the means)
@@ -71,18 +241,18 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
   ==> (10, 2)
 
   # Get 1000 samples from the 10 independent chains.
-  state = tfb.hmc.sample_chain(
+  chains_states, _ = tfb.hmc.sample_chain(
       num_results=1000,
       target_log_prob_fn=target.log_prob,
      current_state=initial_state,
      step_size=0.05,
      num_leapfrog_steps=20,
      num_burnin_steps=200)
-  state.shape
+  chains_states.shape
   ==> (1000, 10, 2)
 
   rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
-      state, independent_chain_ndims=1)
+      chains_states, independent_chain_ndims=1)
 
   # The second dimension needed a longer burn-in.
   rhat.eval()
@@ -108,9 +278,9 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
   Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4.
 
   Args:
-    state: `Tensor` or Python `list` of `Tensor`s representing the state(s) of
-      a Markov Chain at each result step. The `ith` state is assumed to have
-      shape `[Ni, Ci1, Ci2,...,CiD] + A`.
+    chains_states: `Tensor` or Python `list` of `Tensor`s representing the
+      state(s) of a Markov Chain at each result step. The `ith` state is
+      assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`.
       Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
       Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
       chains to be tested for convergence to the same target.
@@ -129,6 +299,10 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
   Raises:
     ValueError: If `independent_chain_ndims < 1`.
   """
+  chains_states_was_list = _is_list_like(chains_states)
+  if not chains_states_was_list:
+    chains_states = [chains_states]
+
   # tensor_util.constant_value returns None iff a constant value (as a numpy
   # array) is not efficiently computable. Therefore, we try constant_value then
   # check for None.
@@ -140,66 +314,53 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
     raise ValueError(
         "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format(
             independent_chain_ndims))
-  with ops.name_scope(
-      name,
-      "potential_scale_reduction",
-      values=[state, independent_chain_ndims]):
-    if _is_list_like(state):
-      return [
-          _potential_scale_reduction_single_state(s, independent_chain_ndims)
-          for s in state
-      ]
-    return _potential_scale_reduction_single_state(state,
-                                                   independent_chain_ndims)
+
+  with ops.name_scope(name, "potential_scale_reduction"):
+    rhat_list = [
+        _potential_scale_reduction_single_state(s, independent_chain_ndims)
+        for s in chains_states
+    ]
+
+  if chains_states_was_list:
+    return rhat_list
+  return rhat_list[0]
 
 
 def _potential_scale_reduction_single_state(state, independent_chain_ndims):
   """potential_scale_reduction for one single state `Tensor`."""
-  # We assume exactly one leading dimension indexes e.g. correlated samples from
-  # each Markov chain.
-  state = ops.convert_to_tensor(state, name="state")
-  sample_ndims = 1
-
-  sample_axis = math_ops.range(0, sample_ndims)
-  chain_axis = math_ops.range(sample_ndims,
-                              sample_ndims + independent_chain_ndims)
-  sample_and_chain_axis = math_ops.range(0,
-                                         sample_ndims + independent_chain_ndims)
-
-  n = _axis_size(state, sample_axis)
-  m = _axis_size(state, chain_axis)
-
-  # In the language of [2],
-  # B / n is the between chain variance, the variance of the chain means.
-  # W is the within sequence variance, the mean of the chain variances.
-  b_div_n = _reduce_variance(
-      math_ops.reduce_mean(state, sample_axis, keepdims=True),
-      sample_and_chain_axis,
-      biased=False)
-  w = math_ops.reduce_mean(
-      _reduce_variance(state, sample_axis, keepdims=True, biased=True),
-      sample_and_chain_axis)
-
-  # sigma^2_+ is an estimate of the true variance, which would be unbiased if
-  # each chain was drawn from the target. c.f. "law of total variance."
-  sigma_2_plus = w + b_div_n
-
-  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
-
-
-def effective_sample_size(state,
-                          independent_chain_ndims=1,
-                          max_lags=None,
-                          max_lags_threshold=None,
-                          name="effective_sample_size"):
-  if max_lags is not None and max_lags_threshold is not None:
-    raise ValueError(
-        "Expected at most one of max_lags, max_lags_threshold to be provided. "
-        "Found: {}, {}".format(max_lags, max_lags_threshold))
 
   with ops.name_scope(
-      name,
-      values=[state, independent_chain_ndims, max_lags, max_lags_threshold]):
-    pass
+      "potential_scale_reduction_single_state",
+      values=[state, independent_chain_ndims]):
+    # We assume exactly one leading dimension indexes e.g. correlated samples
+    # from each Markov chain.
+    state = ops.convert_to_tensor(state, name="state")
+    sample_ndims = 1
+
+    sample_axis = math_ops.range(0, sample_ndims)
+    chain_axis = math_ops.range(sample_ndims,
+                                sample_ndims + independent_chain_ndims)
+    sample_and_chain_axis = math_ops.range(
+        0, sample_ndims + independent_chain_ndims)
+
+    n = _axis_size(state, sample_axis)
+    m = _axis_size(state, chain_axis)
+
+    # In the language of [2],
+    # B / n is the between chain variance, the variance of the chain means.
+    # W is the within sequence variance, the mean of the chain variances.
+    b_div_n = _reduce_variance(
+        math_ops.reduce_mean(state, sample_axis, keepdims=True),
+        sample_and_chain_axis,
+        biased=False)
+    w = math_ops.reduce_mean(
+        _reduce_variance(state, sample_axis, keepdims=True, biased=True),
+        sample_and_chain_axis)
+
+    # sigma^2_+ is an estimate of the true variance, which would be unbiased
+    # if each chain was drawn from the target; cf. the "law of total variance."
+    sigma_2_plus = w + b_div_n
+
+    return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
 
 
 # TODO(b/72873233) Move some variant of this to sample_stats.
@@ -226,3 +387,15 @@ def _axis_size(x, axis=None):
 def _is_list_like(x):
   """Helper which returns `True` if input is `list`-like."""
   return isinstance(x, (tuple, list))
+
+
+def _broadcast_maybelist_arg(states, secondary_arg, name):
+  """Broadcast a listable secondary_arg to match the length of states."""
+  if _is_list_like(secondary_arg):
+    if len(secondary_arg) != len(states):
+      raise ValueError(
+          "Argument `{}` was a list of different length ({}) than `states` "
+          "({})".format(name, len(secondary_arg), len(states)))
+  else:
+    secondary_arg = [secondary_arg] * len(states)
+
+  return secondary_arg