author    Ian Langmore <langmore@google.com>              2018-02-08 12:15:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-08 12:18:33 -0800
commit    5e10de83a4586c85bbbca313a878987f606fb00b (patch)
tree      7bae66e8429bb38f5a72c7661aef7e6c65b765a8 /tensorflow/contrib/bayesflow
parent    52d51aae52d978a66242a4a2f3342aab7e112443 (diff)
Add effective_sample_size to tf.contrib.bayesflow.mcmc_diagnostics.
Also, start dealing with list args in a more regular manner.

PiperOrigin-RevId: 185032115
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD                                          |   1
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py   | 156
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py                 |   1
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py            | 299
4 files changed, 393 insertions(+), 64 deletions(-)
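Before the diff itself, a minimal usage sketch of the list-args convention this commit introduces (standalone illustration, not part of the patch; assumes a TF 1.x build containing this commit):

```
import numpy as np
import tensorflow as tf
from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics

rng = np.random.RandomState(42)
# x_ repeats each iid draw 10 times ==> correlation length 10 ==> ESS ~= N / 10.
iid_x_ = rng.randn(5000, 1).astype(np.float32)
x_ = (iid_x_ * np.ones((5000, 10), dtype=np.float32)).reshape((50000,))
# y_ is iid ==> ESS ~= N.
y_ = rng.randn(50000).astype(np.float32)

# List in, list out.  The truncation args may be scalars (applied to every
# state) or lists of the same length as `states`.
ess = mcmc_diagnostics.effective_sample_size(
    [tf.constant(x_), tf.constant(y_)],
    max_lags_threshold=0.)  # Truncate once auto-correlation drops below 0.

with tf.Session() as sess:
  ess_x_, ess_y_ = sess.run(ess)  # Expect roughly 5000 and 50000.
```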
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 34156c28fe..8c856bb0b7 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -144,6 +144,7 @@ cuda_py_test(
additional_deps = [
":bayesflow_py",
"//third_party/py/numpy",
+ "//tensorflow/python:spectral_ops_test_util",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
index 7652b6a7ce..d68fc9081a 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
@@ -22,11 +22,165 @@ import numpy as np
from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import spectral_ops_test_util
from tensorflow.python.platform import test
rng = np.random.RandomState(42)
+class _EffectiveSampleSizeTest(object):
+
+ @property
+ def use_static_shape(self):
+ raise NotImplementedError(
+ "Subclass failed to implement `use_static_shape`.")
+
+ def _check_versus_expected_effective_sample_size(self,
+ x_,
+ expected_ess,
+ sess,
+ atol=1e-2,
+ rtol=1e-2,
+ max_lags_threshold=None,
+ max_lags=None):
+ x = array_ops.placeholder_with_default(
+ input=x_, shape=x_.shape if self.use_static_shape else None)
+ ess = mcmc_diagnostics.effective_sample_size(
+ x, max_lags_threshold=max_lags_threshold, max_lags=max_lags)
+ if self.use_static_shape:
+ self.assertAllEqual(x.shape[1:], ess.shape)
+
+ ess_ = sess.run(ess)
+
+ self.assertAllClose(
+ np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol)
+
+ def testIidRank1NormalHasFullEssMaxLags10(self):
+ # With a length 5000 iid normal sequence, and max_lags = 10, we should
+ # have a good estimate of ESS, and it should be close to the full sequence
+ # length of 5000.
+    # The choice of max_lags = 10 is a short cutoff, reasonable only because we
+    # know the auto-correlation of an iid sequence is essentially zero beyond
+    # lag 0.
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=rng.randn(5000).astype(np.float32),
+ expected_ess=5000,
+ sess=sess,
+ max_lags=10,
+ rtol=0.3)
+
+ def testIidRank2NormalHasFullEssMaxLags10(self):
+ # See similar test for Rank1Normal for reasoning.
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=rng.randn(5000, 2).astype(np.float32),
+ expected_ess=5000,
+ sess=sess,
+ max_lags=10,
+ rtol=0.3)
+
+ def testIidRank1NormalHasFullEssMaxLagThresholdZero(self):
+ # With a length 5000 iid normal sequence, and max_lags_threshold = 0,
+ # we should have a super-duper estimate of ESS, and it should be very close
+ # to the full sequence length of 5000.
+    # The choice of max_lags_threshold = 0 means we cut off as soon as the
+    # auto-correlation drops below zero. This should happen very quickly, since
+    # the theoretical auto-correlation is [1, 0, 0, ...].
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=rng.randn(5000).astype(np.float32),
+ expected_ess=5000,
+ sess=sess,
+ max_lags_threshold=0.,
+ rtol=0.1)
+
+ def testIidRank2NormalHasFullEssMaxLagThresholdZero(self):
+ # See similar test for Rank1Normal for reasoning.
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=rng.randn(5000, 2).astype(np.float32),
+ expected_ess=5000,
+ sess=sess,
+ max_lags_threshold=0.,
+ rtol=0.1)
+
+ def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self):
+ # Create x_, such that
+ # x_[i] = iid_x_[0], i = 0,...,9
+ # x_[i] = iid_x_[1], i = 10,..., 19,
+ # and so on.
+ iid_x_ = rng.randn(5000, 1).astype(np.float32)
+ x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=x_,
+ expected_ess=50000 // 10,
+ sess=sess,
+ max_lags=50,
+ rtol=0.2)
+
+ def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero(
+ self):
+ # Create x_, such that
+ # x_[i] = iid_x_[0], i = 0,...,9
+ # x_[i] = iid_x_[1], i = 10,..., 19,
+ # and so on.
+ iid_x_ = rng.randn(5000, 1).astype(np.float32)
+ x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ self._check_versus_expected_effective_sample_size(
+ x_=x_,
+ expected_ess=50000 // 10,
+ sess=sess,
+ max_lags_threshold=0.,
+ rtol=0.1)
+
+ def testListArgs(self):
+ # x_ has correlation length 10 ==> ESS = N / 10
+ # y_ has correlation length 1 ==> ESS = N
+ iid_x_ = rng.randn(5000, 1).astype(np.float32)
+ x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
+ y_ = rng.randn(50000).astype(np.float32)
+ states = [x_, x_, y_, y_]
+ max_lags_threshold = [0., None, 0., None]
+ max_lags = [None, 5, None, 5]
+
+ # See other tests for reasoning on tolerance.
+ with self.test_session() as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ ess = mcmc_diagnostics.effective_sample_size(
+ states,
+ max_lags_threshold=max_lags_threshold,
+ max_lags=max_lags)
+ ess_ = sess.run(ess)
+ self.assertAllEqual(4, len(ess_))
+
+ self.assertAllClose(50000 // 10, ess_[0], rtol=0.3)
+ self.assertAllClose(50000 // 10, ess_[1], rtol=0.3)
+ self.assertAllClose(50000, ess_[2], rtol=0.1)
+ self.assertAllClose(50000, ess_[3], rtol=0.1)
+
+
+class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest):
+
+ @property
+ def use_static_shape(self):
+ return True
+
+
+class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest):
+
+ @property
+ def use_static_shape(self):
+ return False
+
+
class _PotentialScaleReductionTest(object):
@property
@@ -48,7 +202,7 @@ class _PotentialScaleReductionTest(object):
state_1 = rng.randn(n_samples, 3, 4) + offset
rhat = mcmc_diagnostics.potential_scale_reduction(
- state=[state_0, state_1], independent_chain_ndims=1)
+ chains_states=[state_0, state_1], independent_chain_ndims=1)
self.assertIsInstance(rhat, list)
with self.test_session() as sess:
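A quick NumPy sanity check of the repeated-block construction used in the correlation tests above (standalone, not part of the patch): repeating each iid draw 10 times gives lag-k auto-correlation of about (10 - k) / 10 for k < 10 and ~0 beyond, so 1 + 2 * Sum_{k>=1} R_k ~= 10 and ESS ~= N / 10.

```
import numpy as np

rng = np.random.RandomState(42)
iid_x = rng.randn(5000, 1)
x = (iid_x * np.ones((5000, 10))).reshape((50000,))

def auto_corr(x, k):
  """Lag-k auto-correlation estimate, normalized by the lag-0 term."""
  x = x - x.mean()
  return np.dot(x[:len(x) - k], x[k:]) / np.dot(x, x)

# Empirical R_k at lags 0..10: close to [1.0, 0.9, ..., 0.1, 0.0].
r = np.array([auto_corr(x, k) for k in range(11)])

# ESS ~= N / (1 + 2 * Sum_{k>=1} R_k); the sum is ~4.5, so ESS ~= N / 10.
print(len(x) / (1. + 2. * r[1:].sum()))  # Roughly 5000.
```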
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
index 5f3e6ade70..f3a645eafc 100644
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
+++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
+ "effective_sample_size",
"potential_scale_reduction",
]
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
index 3b6f92463e..bb8b915a9b 100644
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.
+@@effective_sample_size
@@potential_scale_reduction
"""
@@ -21,20 +22,189 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.distributions.python.ops import sample_stats
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
__all__ = [
+ "effective_sample_size",
"potential_scale_reduction",
]
-def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
+def effective_sample_size(states,
+ max_lags_threshold=None,
+ max_lags=None,
+ name=None):
+ """Estimate a lower bound on effective sample size for each independent chain.
+
+ Roughly speaking, the "effective sample size" (ESS) is the size of an iid
+  sample with the same variance as `states`.
+
+ More precisely, given a stationary sequence of possibly correlated random
+  variables `X_1, X_2,...,X_N`, each identically distributed, ESS is the
+  number such that
+
+ ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.```
+
+ If the sequence is uncorrelated, `ESS = N`. In general, one should expect
+ `ESS <= N`, with more highly correlated sequences having smaller `ESS`.
+
+ #### Example of using ESS to estimate standard error.
+
+ ```
+ tfd = tf.contrib.distributions
+ tfb = tf.contrib.bayesflow
+
+ target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
+
+ # Get 1000 states from one chain.
+  states, _ = tfb.hmc.sample_chain(
+ num_results=1000,
+ target_log_prob_fn=target.log_prob,
+ current_state=tf.constant([0., 0.]),
+ step_size=0.05,
+ num_leapfrog_steps=20,
+ num_burnin_steps=200)
+ states.shape
+ ==> (1000, 2)
+
+ ess = effective_sample_size(states)
+ ==> Shape (2,) Tensor
+
+  mean, variance = tf.nn.moments(states, axes=[0])
+ standard_error = tf.sqrt(variance / ess)
+ ```
+
+ Some math shows that, with `R_k` the auto-correlation sequence,
+ `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have
+
+ ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]```
+
+ This function estimates the above by first estimating the auto-correlation.
+ Since `R_k` must be estimated using only `N - k` samples, it becomes
+ progressively noisier for larger `k`. For this reason, the summation over
+ `R_k` should be truncated at some number `max_lags < N`. Since many MCMC
+  methods generate chains where `R_k > 0`, a reasonable criterion is to truncate
+ at the first index where the estimated auto-correlation becomes negative.
+
+ Args:
+ states: `Tensor` or list of `Tensor` objects. Dimension zero should index
+ identically distributed states.
+ max_lags_threshold: `Tensor` or list of `Tensor` objects.
+      Must broadcast with `states`. The auto-correlation sequence is truncated
+ after the first appearance of a term less than `max_lags_threshold`. If
+ both `max_lags` and `max_lags_threshold` are `None`,
+ `max_lags_threshold` defaults to `0`.
+ max_lags: `Tensor` or list of `Tensor` objects. Must be `int`-like and
+ scalar valued. The auto-correlation sequence is truncated to this length.
+ May be provided only if `max_lags_threshold` is not.
+ name: `String` name to prepend to created ops.
+
+ Returns:
+ ess: `Tensor` or list of `Tensor` objects. The effective sample size of
+ each component of `states`. Shape will be `states.shape[1:]`.
+
+ Raises:
+ ValueError: If `states` and `max_lags_threshold` or `states` and `max_lags`
+ are both lists with different lengths.
+ """
+ states_was_list = _is_list_like(states)
+
+ # Convert all args to lists.
+ if not states_was_list:
+ states = [states]
+
+ max_lags = _broadcast_maybelist_arg(states, max_lags, "max_lags")
+ max_lags_threshold = _broadcast_maybelist_arg(states, max_lags_threshold,
+ "max_lags_threshold")
+
+ # Process items, one at a time.
+ with ops.name_scope(name, "effective_sample_size"):
+ ess_list = [
+ _effective_sample_size_single_state(s, ml, mlt)
+ for (s, ml, mlt) in zip(states, max_lags, max_lags_threshold)
+ ]
+
+ if states_was_list:
+ return ess_list
+ return ess_list[0]
+
+
+def _effective_sample_size_single_state(states, max_lags, max_lags_threshold):
+ """ESS computation for one single Tensor argument."""
+ if max_lags is not None and max_lags_threshold is not None:
+ raise ValueError(
+ "Expected at most one of max_lags, max_lags_threshold to be provided. "
+ "Found: {}, {}".format(max_lags, max_lags_threshold))
+
+ if max_lags_threshold is None:
+ max_lags_threshold = 0.
+
+ with ops.name_scope(
+ "effective_sample_size_single_state",
+ values=[states, max_lags, max_lags_threshold]):
+
+ states = ops.convert_to_tensor(states, name="states")
+ dt = states.dtype
+
+ if max_lags is not None:
+ auto_corr = sample_stats.auto_correlation(
+ states, axis=0, max_lags=max_lags)
+ elif max_lags_threshold is not None:
+ max_lags_threshold = ops.convert_to_tensor(
+ max_lags_threshold, dtype=dt, name="max_lags_threshold")
+ auto_corr = sample_stats.auto_correlation(states, axis=0)
+ # Get a binary mask to zero out values of auto_corr below the threshold.
+      # mask[i, ...] = 1 if auto_corr[j, ...] >= threshold for all j <= i,
+ # mask[i, ...] = 0, otherwise.
+ # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
+ # Building step by step,
+ # Assume auto_corr = [1, 0.5, 0.0, 0.3], and max_lags_threshold = 0.2.
+ # Step 1: mask = [False, False, True, False]
+ mask = auto_corr < max_lags_threshold
+      # Step 2:  mask = [0, 0, 1, 0]
+      mask = math_ops.cast(mask, dtype=dt)
+      # Step 3:  mask = [0, 0, 1, 1]
+ mask = math_ops.cumsum(mask, axis=0)
+ # Step 4: mask = [1, 1, 0, 0]
+ mask = math_ops.maximum(1. - mask, 0.)
+ auto_corr *= mask
+ else:
+ auto_corr = sample_stats.auto_correlation(states, axis=0)
+
+ # With R[k] := auto_corr[k, ...],
+    # ESS = N / {1 + 2 * Sum_{k=1}^{N-1} (N - k) / N * R[k]}
+    #     = N / {-1 + 2 * Sum_{k=0}^{N-1} (N - k) / N * R[k]} (since R[0] = 1)
+    #     approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]},
+    # where M is the max_lags truncation point chosen above.
+
+ # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
+ # ndims the same as auto_corr
+ n = _axis_size(states, axis=0)
+ k = math_ops.range(0., _axis_size(auto_corr, axis=0))
+ nk_factor = (n - k) / n
+ if auto_corr.shape.ndims is not None:
+ new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1)
+ else:
+ new_shape = array_ops.concat(
+ ([-1],
+ array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)),
+ axis=0)
+ nk_factor = array_ops.reshape(nk_factor, new_shape)
+
+ return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0))
+
+
+def potential_scale_reduction(chains_states,
+ independent_chain_ndims=1,
+ name=None):
"""Gelman and Rubin's potential scale reduction factor for chain convergence.
- Given `N > 1` samples from each of `C > 1` independent chains, the potential
+ Given `N > 1` states from each of `C > 1` independent chains, the potential
scale reduction factor, commonly referred to as R-hat, measures convergence of
the chains (to the same target) by testing for equality of means.
Specifically, R-hat measures the degree to which variance (of the means)
@@ -71,18 +241,18 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
==> (10, 2)
# Get 1000 samples from the 10 independent chains.
- state = tfb.hmc.sample_chain(
+ chains_states, _ = tfb.hmc.sample_chain(
num_results=1000,
target_log_prob_fn=target.log_prob,
current_state=initial_state,
step_size=0.05,
num_leapfrog_steps=20,
num_burnin_steps=200)
- state.shape
+ chains_states.shape
==> (1000, 10, 2)
rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
- state, independent_chain_ndims=1)
+ chains_states, independent_chain_ndims=1)
# The second dimension needed a longer burn-in.
rhat.eval()
@@ -108,9 +278,9 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4.
Args:
- state: `Tensor` or Python `list` of `Tensor`s representing the state(s) of
- a Markov Chain at each result step. The `ith` state is assumed to have
- shape `[Ni, Ci1, Ci2,...,CiD] + A`.
+ chains_states: `Tensor` or Python `list` of `Tensor`s representing the
+ state(s) of a Markov Chain at each result step. The `ith` state is
+ assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`.
Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
chains to be tested for convergence to the same target.
@@ -129,6 +299,10 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
Raises:
ValueError: If `independent_chain_ndims < 1`.
"""
+ chains_states_was_list = _is_list_like(chains_states)
+ if not chains_states_was_list:
+ chains_states = [chains_states]
+
# tensor_util.constant_value returns None iff a constant value (as a numpy
# array) is not efficiently computable. Therefore, we try constant_value then
# check for None.
@@ -140,66 +314,53 @@ def potential_scale_reduction(state, independent_chain_ndims=1, name=None):
raise ValueError(
"Argument `independent_chain_ndims` must be `>= 1`, found: {}".format(
independent_chain_ndims))
- with ops.name_scope(
- name,
- "potential_scale_reduction",
- values=[state, independent_chain_ndims]):
- if _is_list_like(state):
- return [
- _potential_scale_reduction_single_state(s, independent_chain_ndims)
- for s in state
- ]
- return _potential_scale_reduction_single_state(state,
- independent_chain_ndims)
+
+ with ops.name_scope(name, "potential_scale_reduction"):
+ rhat_list = [
+ _potential_scale_reduction_single_state(s, independent_chain_ndims)
+ for s in chains_states
+ ]
+
+ if chains_states_was_list:
+ return rhat_list
+ return rhat_list[0]
def _potential_scale_reduction_single_state(state, independent_chain_ndims):
"""potential_scale_reduction for one single state `Tensor`."""
- # We assume exactly one leading dimension indexes e.g. correlated samples from
- # each Markov chain.
- state = ops.convert_to_tensor(state, name="state")
- sample_ndims = 1
-
- sample_axis = math_ops.range(0, sample_ndims)
- chain_axis = math_ops.range(sample_ndims,
- sample_ndims + independent_chain_ndims)
- sample_and_chain_axis = math_ops.range(0,
- sample_ndims + independent_chain_ndims)
-
- n = _axis_size(state, sample_axis)
- m = _axis_size(state, chain_axis)
-
- # In the language of [2],
- # B / n is the between chain variance, the variance of the chain means.
- # W is the within sequence variance, the mean of the chain variances.
- b_div_n = _reduce_variance(
- math_ops.reduce_mean(state, sample_axis, keepdims=True),
- sample_and_chain_axis,
- biased=False)
- w = math_ops.reduce_mean(
- _reduce_variance(state, sample_axis, keepdims=True, biased=True),
- sample_and_chain_axis)
-
- # sigma^2_+ is an estimate of the true variance, which would be unbiased if
- # each chain was drawn from the target. c.f. "law of total variance."
- sigma_2_plus = w + b_div_n
-
- return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
-
-
-def effective_sample_size(state,
- independent_chain_ndims=1,
- max_lags=None,
- max_lags_threshold=None,
- name="effective_sample_size"):
- if max_lags is not None and max_lags_threshold is not None:
- raise ValueError(
- "Expected at most one of max_lags, max_lags_threshold to be provided. "
- "Found: {}, {}".format(max_lags, max_lags_threshold))
with ops.name_scope(
- name,
- values=[state, independent_chain_ndims, max_lags, max_lags_threshold]):
- pass
+ "potential_scale_reduction_single_state",
+ values=[state, independent_chain_ndims]):
+ # We assume exactly one leading dimension indexes e.g. correlated samples
+ # from each Markov chain.
+ state = ops.convert_to_tensor(state, name="state")
+ sample_ndims = 1
+
+ sample_axis = math_ops.range(0, sample_ndims)
+ chain_axis = math_ops.range(sample_ndims,
+ sample_ndims + independent_chain_ndims)
+ sample_and_chain_axis = math_ops.range(
+ 0, sample_ndims + independent_chain_ndims)
+
+ n = _axis_size(state, sample_axis)
+ m = _axis_size(state, chain_axis)
+
+ # In the language of [2],
+ # B / n is the between chain variance, the variance of the chain means.
+ # W is the within sequence variance, the mean of the chain variances.
+ b_div_n = _reduce_variance(
+ math_ops.reduce_mean(state, sample_axis, keepdims=True),
+ sample_and_chain_axis,
+ biased=False)
+ w = math_ops.reduce_mean(
+ _reduce_variance(state, sample_axis, keepdims=True, biased=True),
+ sample_and_chain_axis)
+
+ # sigma^2_+ is an estimate of the true variance, which would be unbiased if
+ # each chain was drawn from the target. c.f. "law of total variance."
+ sigma_2_plus = w + b_div_n
+
+ return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
# TODO(b/72873233) Move some variant of this to sample_stats.
@@ -226,3 +387,15 @@ def _axis_size(x, axis=None):
def _is_list_like(x):
"""Helper which returns `True` if input is `list`-like."""
return isinstance(x, (tuple, list))
+
+
+def _broadcast_maybelist_arg(states, secondary_arg, name):
+ """Broadcast a listable secondary_arg to that of states."""
+ if _is_list_like(secondary_arg):
+ if len(secondary_arg) != len(states):
+ raise ValueError("Argument `%s` was a list of different length ({}) than "
+ "`states` ({})".format(name, len(states)))
+ else:
+ secondary_arg = [secondary_arg] * len(states)
+
+ return secondary_arg
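To close, two standalone NumPy sketches (not part of the patch) of the computations above. First, the ESS estimator, mirroring the threshold masking and (N - k) / N weighting in `_effective_sample_size_single_state` for a rank-1 state; `_auto_correlation` is a simplified stand-in for `sample_stats.auto_correlation`, written here as an assumption about its behavior on 1-D input.

```
import numpy as np

def _auto_correlation(x):
  """Simplified rank-1 stand-in for sample_stats.auto_correlation."""
  x = x - x.mean()
  n = len(x)
  return np.array([np.dot(x[:n - k], x[k:]) for k in range(n)]) / np.dot(x, x)

def effective_sample_size_np(x, max_lags_threshold=0.):
  """ESS of a 1-D array, following the masked formula above."""
  r = _auto_correlation(x)
  n = float(len(x))
  # mask = [1, 1, ..., 0, 0, ...]: zero out r once a term falls below threshold.
  mask = np.maximum(1. - np.cumsum(r < max_lags_threshold), 0.)
  r = r * mask
  # ESS = N / (-1 + 2 * Sum_k (N - k) / N * R[k]), with the R[0] = 1 term kept.
  k = np.arange(len(r))
  return n / (-1. + 2. * np.sum((n - k) / n * r))

rng = np.random.RandomState(42)
print(effective_sample_size_np(rng.randn(5000)))  # Close to 5000 for iid input.
```

Second, R-hat as computed in `_potential_scale_reduction_single_state`, for the simplest case of `independent_chain_ndims = 1` and a scalar event shape:

```
import numpy as np

def potential_scale_reduction_np(chains_states):
  """R-hat for an array of shape [n_samples, n_chains]."""
  n = float(chains_states.shape[0])
  m = float(chains_states.shape[1])
  # B / n: unbiased variance of the per-chain means.
  b_div_n = chains_states.mean(axis=0).var(ddof=1)
  # W: mean of the biased within-chain variances.
  w = chains_states.var(axis=0, ddof=0).mean()
  # sigma^2_+ estimates the true variance, c.f. "law of total variance."
  sigma_2_plus = w + b_div_n
  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)

rng = np.random.RandomState(42)
mixed = rng.randn(1000, 4)                  # Four well-mixed chains.
print(potential_scale_reduction_np(mixed))  # Close to 1.
print(potential_scale_reduction_np(mixed + [0., 0., 0., 10.]))  # Much bigger than 1.
```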