author Joshua V. Dillon <jvdillon@google.com> 2018-03-06 11:55:08 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-03-06 11:59:25 -0800
commit c8236883db3b53563b24d527aade12e60d5ed246 (patch)
tree 1a127b295e7b7f6473046f45826d97ce3f8221d4 /tensorflow/contrib/bayesflow
parent 429ce2a60b9faa3db204aed05ab4a9a3a1a6c725 (diff)
Migrate MCMC diagnostics and Halton Sequence sampler into
tensorflow_probability.

PiperOrigin-RevId: 188057302
Diffstat (limited to 'tensorflow/contrib/bayesflow')
-rw-r--r-- tensorflow/contrib/bayesflow/BUILD                                         |  20
-rw-r--r-- tensorflow/contrib/bayesflow/__init__.py                                   |   2
-rw-r--r-- tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py  | 445
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py                |  32
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py           | 400
5 files changed, 0 insertions, 899 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index 7302c9119d..2a32ea6952 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -125,26 +125,6 @@ cuda_py_test(
)
cuda_py_test(
- name = "mcmc_diagnostics_test",
- size = "small",
- srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"],
- additional_deps = [
- ":bayesflow_py",
- "//third_party/py/numpy",
- "//tensorflow/python:spectral_ops_test_util",
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python/ops/distributions",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:random_seed",
- ],
-)
-
-cuda_py_test(
name = "monte_carlo_test",
size = "small",
srcs = ["python/kernel_tests/monte_carlo_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index f2b7fb77a8..156a2ef8cf 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -25,7 +25,6 @@ from tensorflow.contrib.bayesflow.python.ops import custom_grad
from tensorflow.contrib.bayesflow.python.ops import halton_sequence
from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops import layers
-from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
from tensorflow.contrib.bayesflow.python.ops import optimizers
@@ -41,7 +40,6 @@ _allowed_symbols = [
'hmc',
'layers',
'metropolis_hastings',
- 'mcmc_diagnostics',
'monte_carlo',
'optimizers',
'special_math',
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
deleted file mode 100644
index 52e36e135d..0000000000
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py
+++ /dev/null
@@ -1,445 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MCMC diagnostic utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import spectral_ops_test_util
-from tensorflow.python.platform import test
-
-rng = np.random.RandomState(42)
-
-
-class _EffectiveSampleSizeTest(object):
-
- @property
- def use_static_shape(self):
- raise NotImplementedError(
- "Subclass failed to implement `use_static_shape`.")
-
- def _check_versus_expected_effective_sample_size(self,
- x_,
- expected_ess,
- sess,
- atol=1e-2,
- rtol=1e-2,
- filter_threshold=None,
- filter_beyond_lag=None):
- x = array_ops.placeholder_with_default(
- input=x_, shape=x_.shape if self.use_static_shape else None)
- ess = mcmc_diagnostics.effective_sample_size(
- x,
- filter_threshold=filter_threshold,
- filter_beyond_lag=filter_beyond_lag)
- if self.use_static_shape:
- self.assertAllEqual(x.shape[1:], ess.shape)
-
- ess_ = sess.run(ess)
-
- self.assertAllClose(
- np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol)
-
- def testIidRank1NormalHasFullEssMaxLags10(self):
- # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we
- # should have a good estimate of ESS, and it should be close to the full
- # sequence length of 5000.
- # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only
- # since we know the correlation length should be zero right away.
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=rng.randn(5000).astype(np.float32),
- expected_ess=5000,
- sess=sess,
- filter_beyond_lag=10,
- filter_threshold=None,
- rtol=0.3)
-
- def testIidRank2NormalHasFullEssMaxLags10(self):
- # See similar test for Rank1Normal for reasoning.
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=rng.randn(5000, 2).astype(np.float32),
- expected_ess=5000,
- sess=sess,
- filter_beyond_lag=10,
- filter_threshold=None,
- rtol=0.3)
-
- def testIidRank1NormalHasFullEssMaxLagThresholdZero(self):
- # With a length 5000 iid normal sequence and filter_threshold = 0, we
- # should have a very good estimate of ESS, and it should be very close
- # to the full sequence length of 5000.
- # The choice of filter_threshold = 0 means we cut off as soon as the
- # auto-corr is below zero. This should happen very quickly, due to the
- # fact that the theoretical auto-corr is [1, 0, 0,...]
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=rng.randn(5000).astype(np.float32),
- expected_ess=5000,
- sess=sess,
- filter_beyond_lag=None,
- filter_threshold=0.,
- rtol=0.1)
-
- def testIidRank2NormalHasFullEssMaxLagThresholdZero(self):
- # See similar test for Rank1Normal for reasoning.
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=rng.randn(5000, 2).astype(np.float32),
- expected_ess=5000,
- sess=sess,
- filter_beyond_lag=None,
- filter_threshold=0.,
- rtol=0.1)
-
- def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self):
- # Create x_, such that
- # x_[i] = iid_x_[0], i = 0,...,9
- # x_[i] = iid_x_[1], i = 10,..., 19,
- # and so on.
- iid_x_ = rng.randn(5000, 1).astype(np.float32)
- x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=x_,
- expected_ess=50000 // 10,
- sess=sess,
- filter_beyond_lag=50,
- filter_threshold=None,
- rtol=0.2)
-
- def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero(
- self):
- # Create x_, such that
- # x_[i] = iid_x_[0], i = 0,...,9
- # x_[i] = iid_x_[1], i = 10,..., 19,
- # and so on.
- iid_x_ = rng.randn(5000, 1).astype(np.float32)
- x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- self._check_versus_expected_effective_sample_size(
- x_=x_,
- expected_ess=50000 // 10,
- sess=sess,
- filter_beyond_lag=None,
- filter_threshold=0.,
- rtol=0.1)
-
- def testListArgs(self):
- # x_ has correlation length 10 ==> ESS = N / 10
- # y_ has correlation length 1 ==> ESS = N
- iid_x_ = rng.randn(5000, 1).astype(np.float32)
- x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,))
- y_ = rng.randn(50000).astype(np.float32)
- states = [x_, x_, y_, y_]
- filter_threshold = [0., None, 0., None]
- filter_beyond_lag = [None, 5, None, 5]
-
- # See other tests for reasoning on tolerance.
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- ess = mcmc_diagnostics.effective_sample_size(
- states,
- filter_threshold=filter_threshold,
- filter_beyond_lag=filter_beyond_lag)
- ess_ = sess.run(ess)
- self.assertAllEqual(4, len(ess_))
-
- self.assertAllClose(50000 // 10, ess_[0], rtol=0.3)
- self.assertAllClose(50000 // 10, ess_[1], rtol=0.3)
- self.assertAllClose(50000, ess_[2], rtol=0.1)
- self.assertAllClose(50000, ess_[3], rtol=0.1)
-
- def testMaxLagsThresholdLessThanNeg1SameAsNone(self):
- # Setting both means we filter out items R_k from the auto-correlation
- # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold.
-
- # x_ has correlation length 10.
- iid_x_ = rng.randn(500, 1).astype(np.float32)
- x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,))
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- x = array_ops.placeholder_with_default(
- input=x_, shape=x_.shape if self.use_static_shape else None)
-
- ess_none_none = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=None, filter_beyond_lag=None)
- ess_none_200 = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=None, filter_beyond_lag=200)
- ess_neg2_200 = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=-2., filter_beyond_lag=200)
- ess_neg2_none = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=-2., filter_beyond_lag=None)
- ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run(
- [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none])
-
- # filter_threshold=-2 <==> filter_threshold=None.
- self.assertAllClose(ess_none_none_, ess_neg2_none_)
- self.assertAllClose(ess_none_200_, ess_neg2_200_)
-
- def testMaxLagsArgsAddInAnOrManner(self):
- # Setting both means we filter out items R_k from the auto-correlation
- # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold.
-
- # x_ has correlation length 10.
- iid_x_ = rng.randn(500, 1).astype(np.float32)
- x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,))
- with self.test_session() as sess:
- with spectral_ops_test_util.fft_kernel_label_map():
- x = array_ops.placeholder_with_default(
- input=x_, shape=x_.shape if self.use_static_shape else None)
-
- ess_1_9 = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=1., filter_beyond_lag=9)
- ess_1_none = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=1., filter_beyond_lag=None)
- ess_none_9 = mcmc_diagnostics.effective_sample_size(
- x, filter_threshold=None, filter_beyond_lag=9)
- ess_1_9_, ess_1_none_, ess_none_9_ = sess.run(
- [ess_1_9, ess_1_none, ess_none_9])
-
- # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10,
- # filter_threshold = 1 <==> filter_beyond_lag = 9.
- self.assertAllClose(ess_1_9_, ess_1_none_)
- self.assertAllClose(ess_1_9_, ess_none_9_)
-
-
-class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest):
-
- @property
- def use_static_shape(self):
- return True
-
-
-class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest):
-
- @property
- def use_static_shape(self):
- return False
-
-
-class _PotentialScaleReductionTest(object):
-
- @property
- def use_static_shape(self):
- raise NotImplementedError(
- "Subclass failed to impliment `use_static_shape`.")
-
- def testListOfStatesWhereFirstPassesSecondFails(self):
- """Simple test showing API with two states. Read first!."""
- n_samples = 1000
-
- # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass.
- state_0 = rng.randn(n_samples, 2)
-
- # state_1 is three 4-variate chains taken from Normal(0, 1) that have been
- # shifted. Since every chain is shifted, they are not the same, and the
- # test should fail.
- offset = np.array([1., -1., 2.]).reshape(3, 1)
- state_1 = rng.randn(n_samples, 3, 4) + offset
-
- rhat = mcmc_diagnostics.potential_scale_reduction(
- chains_states=[state_0, state_1], independent_chain_ndims=1)
-
- self.assertIsInstance(rhat, list)
- with self.test_session() as sess:
- rhat_0_, rhat_1_ = sess.run(rhat)
-
- # r_hat_0 should be close to 1, meaning the test passed.
- self.assertAllEqual((), rhat_0_.shape)
- self.assertAllClose(1., rhat_0_, rtol=0.02)
-
- # r_hat_1 should be greater than 1.2, meaning the test failed.
- self.assertAllEqual((4,), rhat_1_.shape)
- self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2)
-
- def check_results(self, state_, independent_chain_shape, should_pass):
- sample_ndims = 1
- independent_chain_ndims = len(independent_chain_shape)
- with self.test_session():
- state = array_ops.placeholder_with_default(
- input=state_, shape=state_.shape if self.use_static_shape else None)
-
- rhat = mcmc_diagnostics.potential_scale_reduction(
- state, independent_chain_ndims=independent_chain_ndims)
-
- if self.use_static_shape:
- self.assertAllEqual(
- state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape)
-
- rhat_ = rhat.eval()
- if should_pass:
- self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02)
- else:
- self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2)
-
- def iid_normal_chains_should_pass_wrapper(self,
- sample_shape,
- independent_chain_shape,
- other_shape,
- dtype=np.float32):
- """Check results with iid normal chains."""
-
- state_shape = sample_shape + independent_chain_shape + other_shape
- state_ = rng.randn(*state_shape).astype(dtype)
-
- # The "other" dimensions do not have to be identical, just independent, so
- # force them to not be identical.
- if other_shape:
- state_ *= rng.rand(*other_shape).astype(dtype)
-
- self.check_results(state_, independent_chain_shape, should_pass=True)
-
- def testPassingIIDNdimsAreIndependentOneOtherZero(self):
- self.iid_normal_chains_should_pass_wrapper(
- sample_shape=[10000], independent_chain_shape=[4], other_shape=[])
-
- def testPassingIIDNdimsAreIndependentOneOtherOne(self):
- self.iid_normal_chains_should_pass_wrapper(
- sample_shape=[10000], independent_chain_shape=[3], other_shape=[7])
-
- def testPassingIIDNdimsAreIndependentOneOtherTwo(self):
- self.iid_normal_chains_should_pass_wrapper(
- sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7])
-
- def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self):
- self.iid_normal_chains_should_pass_wrapper(
- sample_shape=[10000],
- independent_chain_shape=[2, 3],
- other_shape=[5, 7],
- dtype=np.float64)
-
- def offset_normal_chains_should_fail_wrapper(
- self, sample_shape, independent_chain_shape, other_shape):
- """Check results with normal chains that are offset from each other."""
-
- state_shape = sample_shape + independent_chain_shape + other_shape
- state_ = rng.randn(*state_shape)
-
- # Add a significant offset to the different (formerly iid) chains.
- offset = np.linspace(
- 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len(
- sample_shape) + independent_chain_shape + [1] * len(other_shape))
- state_ += offset
-
- self.check_results(state_, independent_chain_shape, should_pass=False)
-
- def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self):
- self.offset_normal_chains_should_fail_wrapper(
- sample_shape=[10000], independent_chain_shape=[2], other_shape=[5])
-
-
-class PotentialScaleReductionStaticTest(test.TestCase,
- _PotentialScaleReductionTest):
-
- @property
- def use_static_shape(self):
- return True
-
- def testIndependentNdimsLessThanOneRaises(self):
- with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"):
- mcmc_diagnostics.potential_scale_reduction(
- rng.rand(2, 3, 4), independent_chain_ndims=0)
-
-
-class PotentialScaleReductionDynamicTest(test.TestCase,
- _PotentialScaleReductionTest):
-
- @property
- def use_static_shape(self):
- return False
-
-
-class _ReduceVarianceTest(object):
-
- @property
- def use_static_shape(self):
- raise NotImplementedError(
- "Subclass failed to impliment `use_static_shape`.")
-
- def check_versus_numpy(self, x_, axis, biased, keepdims):
- with self.test_session():
- x_ = np.asarray(x_)
- x = array_ops.placeholder_with_default(
- input=x_, shape=x_.shape if self.use_static_shape else None)
- var = mcmc_diagnostics._reduce_variance(
- x, axis=axis, biased=biased, keepdims=keepdims)
- np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims)
-
- if self.use_static_shape:
- self.assertAllEqual(np_var.shape, var.shape)
-
- var_ = var.eval()
- # We will mask below, which changes shape, so check shape explicitly here.
- self.assertAllEqual(np_var.shape, var_.shape)
-
- # We get NaN when the sample size equals ddof, since we then divide by zero.
- nan_mask = np.isnan(np_var)
- if nan_mask.any():
- self.assertTrue(np.isnan(var_[nan_mask]).all())
- self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02)
-
- def testScalarBiasedTrue(self):
- self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False)
-
- def testScalarBiasedFalse(self):
- # This should result in NaN.
- self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False)
-
- def testShape2x3x4AxisNoneBiasedTrueKeepdimsFalse(self):
- self.check_versus_numpy(
- x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False)
-
- def testShape2x3x4Axis1BiasedTrueKeepdimsTrue(self):
- self.check_versus_numpy(
- x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True)
-
- def testShape2x3x4x5Axis13BiasedTrueKeepdimsTrue(self):
- self.check_versus_numpy(
- x_=rng.randn(2, 3, 4, 5), axis=(1, 3), biased=True, keepdims=True)
-
- def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self):
- self.check_versus_numpy(
- x_=rng.randn(2, 3, 4, 5), axis=(1, 3), biased=False, keepdims=False)
-
-
-class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest):
-
- @property
- def use_static_shape(self):
- return True
-
-
-class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest):
-
- @property
- def use_static_shape(self):
- return False
-
-
-if __name__ == "__main__":
- test.main()
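
The construction used throughout these tests (repeating each of `N` iid draws
`k` times) yields a sequence with correlation length `k` and hence an ESS of
roughly `N`. A minimal NumPy sketch of that check, using the truncated-sum ESS
formula documented in `mcmc_diagnostics_impl.py` below; `ess_estimate` is a
hypothetical helper, not part of the deleted module:

```python
import numpy as np

def ess_estimate(x, max_lag=50):
  """Truncated-sum ESS: N / (1 + 2 * sum_{k=1}^{M} (N - k)/N * R_k)."""
  n = len(x)
  xc = x - x.mean()
  denom = np.sum(xc * xc)
  lags = np.arange(1, max_lag + 1)
  # Sample auto-correlation R_k, normalized so that R_0 == 1.
  r = np.array([np.sum(xc[:-k] * xc[k:]) / denom for k in lags])
  return n / (1. + 2. * np.sum((n - lags) / n * r))

rng = np.random.RandomState(42)
x = np.repeat(rng.randn(5000), 10)  # correlation length 10, len(x) == 50000
print(ess_estimate(x))              # roughly 5000, i.e. 50000 / 10
```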
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
deleted file mode 100644
index f3a645eafc..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for Markov Chain Monte Carlo (MCMC) sampling."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import
-from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- "effective_sample_size",
- "potential_scale_reduction",
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
deleted file mode 100644
index 0424b6952b..0000000000
--- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py
+++ /dev/null
@@ -1,400 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.
-
-@@effective_sample_size
-@@potential_scale_reduction
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distributions.python.ops import sample_stats
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-
-__all__ = [
- "effective_sample_size",
- "potential_scale_reduction",
-]
-
-
-def effective_sample_size(states,
- filter_threshold=0.,
- filter_beyond_lag=None,
- name=None):
- """Estimate a lower bound on effective sample size for each independent chain.
-
- Roughly speaking, "effective sample size" (ESS) is the size of an iid sample
- with the same variance as `state`.
-
- More precisely, given a stationary sequence of possibly correlated random
- variables `X_1, X_2,...,X_N`, each identically distributed, ESS is the
- number such that
-
- ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.```
-
- If the sequence is uncorrelated, `ESS = N`. In general, one should expect
- `ESS <= N`, with more highly correlated sequences having smaller `ESS`.
-
- #### Example of using ESS to estimate standard error.
-
- ```python
- tfd = tf.contrib.distributions
- tfb = tf.contrib.bayesflow
-
- target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
-
- # Get 1000 states from one chain.
- states, _ = tfb.hmc.sample_chain(
- num_results=1000,
- target_log_prob_fn=target.log_prob,
- current_state=tf.constant([0., 0.]),
- step_size=0.05,
- num_leapfrog_steps=20,
- num_burnin_steps=200)
- states.shape
- ==> (1000, 2)
-
- ess = effective_sample_size(states)
- ==> Shape (2,) Tensor
-
- mean, variance = tf.nn.moments(states, axes=[0])
- standard_error = tf.sqrt(variance / ess)
- ```
-
- Some math shows that, with `R_k` the auto-correlation sequence,
- `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have
-
- ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]```
-
- This function estimates the above by first estimating the auto-correlation.
- Since `R_k` must be estimated using only `N - k` samples, it becomes
- progressively noisier for larger `k`. For this reason, the summation over
- `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many
- MCMC methods generate chains where `R_k > 0`, a reasonable criterion is to
- truncate at the first index where the estimated auto-correlation becomes
- negative.
-
- The arguments `filter_beyond_lag` and `filter_threshold` are filters intended
- to remove noisy tail terms from `R_k`. They combine in an "OR" manner: terms
- are removed if they would be filtered out under either the `filter_beyond_lag`
- or the `filter_threshold` criterion.
-
- Args:
- states: `Tensor` or list of `Tensor` objects. Dimension zero should index
- identically distributed states.
- filter_threshold: `Tensor` or list of `Tensor` objects.
- Must broadcast with `states`. The auto-correlation sequence is truncated
- after the first appearance of a term less than `filter_threshold`.
- Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`,
- setting to any number less than `-1` has the same effect.
- filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be
- `int`-like and scalar valued. The auto-correlation sequence is truncated
- to this length. Setting to `None` means we do not filter based on number
- of lags.
- name: `String` name to prepend to created ops.
-
- Returns:
- ess: `Tensor` or list of `Tensor` objects. The effective sample size of
- each component of `states`. Shape will be `states.shape[1:]`.
-
- Raises:
- ValueError: If `states` and `filter_threshold` or `states` and
- `filter_beyond_lag` are both lists with different lengths.
- """
- states_was_list = _is_list_like(states)
-
- # Convert all args to lists.
- if not states_was_list:
- states = [states]
-
- filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag,
- "filter_beyond_lag")
- filter_threshold = _broadcast_maybelist_arg(states, filter_threshold,
- "filter_threshold")
-
- # Process items, one at a time.
- with ops.name_scope(name, "effective_sample_size"):
- ess_list = [
- _effective_sample_size_single_state(s, ml, mlt)
- for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold)
- ]
-
- if states_was_list:
- return ess_list
- return ess_list[0]
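
The defining relation above, `Variance{ N**-1 * Sum{X_i} } = ESS**-1 *
Variance{ X_1 }`, can also be checked by direct simulation. A sketch assuming
only NumPy, with the same repeat-`k`-times construction as the unit tests;
none of this is part of the deleted module:

```python
import numpy as np

rng = np.random.RandomState(42)
n, reps, trials = 50000, 10, 200

# Each sequence repeats iid N(0, 1) draws `reps` times, so ESS ~= n / reps.
sample_means = np.array(
    [np.repeat(rng.randn(n // reps), reps).mean() for _ in range(trials)])

# Variance{mean} = Variance{X} / ESS  ==>  ESS = Variance{X} / Variance{mean},
# and here Variance{X} == 1.
print(1. / sample_means.var())  # roughly n / reps == 5000
```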
-
-
-def _effective_sample_size_single_state(states, filter_beyond_lag,
- filter_threshold):
- """ESS computation for one single Tensor argument."""
-
- with ops.name_scope(
- "effective_sample_size_single_state",
- values=[states, filter_beyond_lag, filter_threshold]):
-
- states = ops.convert_to_tensor(states, name="states")
- dt = states.dtype
-
- # filter_beyond_lag == None ==> auto_corr is the full sequence.
- auto_corr = sample_stats.auto_correlation(
- states, axis=0, max_lags=filter_beyond_lag)
- if filter_threshold is not None:
- filter_threshold = ops.convert_to_tensor(
- filter_threshold, dtype=dt, name="filter_threshold")
- # Get a binary mask to zero out values of auto_corr below the threshold.
- # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
- # mask[i, ...] = 0, otherwise.
- # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
- # Building step by step,
- # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2.
- # Step 1: mask = [False, False, True, False]
- mask = auto_corr < filter_threshold
- # Step 2: mask = [0, 0, 1, 0]
- mask = math_ops.cast(mask, dtype=dt)
- # Step 3: mask = [0, 0, 1, 1]
- mask = math_ops.cumsum(mask, axis=0)
- # Step 4: mask = [1, 1, 0, 0]
- mask = math_ops.maximum(1. - mask, 0.)
- auto_corr *= mask
-
- # With R[k] := auto_corr[k, ...],
- # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]}
- # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1)
- # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]}
- # where M is the filter_beyond_lag truncation point chosen above.
-
- # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
- # ndims the same as auto_corr
- n = _axis_size(states, axis=0)
- k = math_ops.range(0., _axis_size(auto_corr, axis=0))
- nk_factor = (n - k) / n
- if auto_corr.shape.ndims is not None:
- new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1)
- else:
- new_shape = array_ops.concat(
- ([-1],
- array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)),
- axis=0)
- nk_factor = array_ops.reshape(nk_factor, new_shape)
-
- return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0))
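
The cast/cumsum/maximum sequence above zeroes out everything at, and after,
the first sub-threshold auto-correlation term without any control flow. A
self-contained NumPy sketch of the same trick, on the example values from the
comments (illustrative only, not part of this module):

```python
import numpy as np

auto_corr = np.array([1., 0.5, 0.0, 0.3])
filter_threshold = 0.2

mask = auto_corr < filter_threshold   # [False, False, True, False]
mask = mask.astype(auto_corr.dtype)   # [0., 0., 1., 0.]
mask = np.cumsum(mask)                # [0., 0., 1., 1.]
mask = np.maximum(1. - mask, 0.)      # [1., 1., 0., 0.]

print(auto_corr * mask)               # [1., 0.5, 0., 0.]
```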
-
-
-def potential_scale_reduction(chains_states,
- independent_chain_ndims=1,
- name=None):
- """Gelman and Rubin's potential scale reduction factor for chain convergence.
-
- Given `N > 1` states from each of `C > 1` independent chains, the potential
- scale reduction factor, commonly referred to as R-hat, measures convergence of
- the chains (to the same target) by testing for equality of means.
- Specifically, R-hat measures the degree to which variance (of the means)
- between chains exceeds what one would expect if the chains were identically
- distributed. See [1], [2].
-
- Some guidelines:
-
- * The initial state of the chains should be drawn from a distribution
- overdispersed with respect to the target.
- * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1.
- Before that, R-hat > 1 (except in pathological cases, e.g. if the chain
- paths were identical).
- * The above holds for any number of chains `C > 1`. Increasing `C` improves
- the effectiveness of the diagnostic.
- * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of
- course this is problem dependent. See [2].
- * R-hat only measures non-convergence of the mean. If higher moments, or other
- statistics are desired, a different diagnostic should be used. See [2].
-
- #### Examples
-
- Diagnosing convergence by monitoring 10 chains that each attempt to
- sample from a 2-variate normal.
-
- ```python
- tfd = tf.contrib.distributions
- tfb = tf.contrib.bayesflow
-
- target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])
-
- # Get 10 (2x) overdispersed initial states.
- initial_state = target.sample(10) * 2.
- initial_state.shape
- ==> (10, 2)
-
- # Get 1000 samples from the 10 independent chains.
- chains_states, _ = tfb.hmc.sample_chain(
- num_results=1000,
- target_log_prob_fn=target.log_prob,
- current_state=initial_state,
- step_size=0.05,
- num_leapfrog_steps=20,
- num_burnin_steps=200)
- chains_states.shape
- ==> (1000, 10, 2)
-
- rhat = tfb.mcmc_diagnostics.potential_scale_reduction(
- chains_states, independent_chain_ndims=1)
-
- # The second dimension needed a longer burn-in.
- rhat.eval()
- ==> [1.05, 1.3]
- ```
-
- To see why R-hat is reasonable, let `X` be a random variable drawn uniformly
- from the combined states (combined over all chains). Then, in the limit
- `N, C --> infinity`, with `E`, `Var` denoting expectation and variance,
-
- ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].```
-
- Using the law of total variance, the numerator is the variance of the combined
- states, and the denominator is the total variance minus the variance of the
- individual chain means. If the chains are all drawing from the same
- distribution, they will have the same mean, and thus the ratio should be one.
-
- [1] "Inference from Iterative Simulation Using Multiple Sequences"
- Andrew Gelman and Donald B. Rubin
- Statist. Sci. Volume 7, Number 4 (1992), 457-472.
- [2] "General Methods for Monitoring Convergence of Iterative Simulations"
- Stephen P. Brooks and Andrew Gelman
- Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4.
-
- Args:
- chains_states: `Tensor` or Python `list` of `Tensor`s representing the
- state(s) of a Markov Chain at each result step. The `ith` state is
- assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`.
- Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain.
- Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent
- chains to be tested for convergence to the same target.
- The remaining dimensions, `A`, can have any shape (even empty).
- independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the
- number of dimensions, from `dim = 1` to `dim = D`, holding independent
- chain results to be tested for convergence.
- name: `String` name to prepend to created ops. Default:
- `potential_scale_reduction`.
-
- Returns:
- `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for
- the state(s). Same `dtype` as `state`, and shape equal to
- `state.shape[1 + independent_chain_ndims:]`.
-
- Raises:
- ValueError: If `independent_chain_ndims < 1`.
- """
- chains_states_was_list = _is_list_like(chains_states)
- if not chains_states_was_list:
- chains_states = [chains_states]
-
- # tensor_util.constant_value returns None iff a constant value (as a numpy
- # array) is not efficiently computable. Therefore, we try constant_value then
- # check for None.
- icn_const_ = tensor_util.constant_value(
- ops.convert_to_tensor(independent_chain_ndims))
- if icn_const_ is not None:
- independent_chain_ndims = icn_const_
- if icn_const_ < 1:
- raise ValueError(
- "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format(
- independent_chain_ndims))
-
- with ops.name_scope(name, "potential_scale_reduction"):
- rhat_list = [
- _potential_scale_reduction_single_state(s, independent_chain_ndims)
- for s in chains_states
- ]
-
- if chains_states_was_list:
- return rhat_list
- return rhat_list[0]
-
-
-def _potential_scale_reduction_single_state(state, independent_chain_ndims):
- """potential_scale_reduction for one single state `Tensor`."""
- with ops.name_scope(
- "potential_scale_reduction_single_state",
- values=[state, independent_chain_ndims]):
- # We assume exactly one leading dimension indexes e.g. correlated samples
- # from each Markov chain.
- state = ops.convert_to_tensor(state, name="state")
- sample_ndims = 1
-
- sample_axis = math_ops.range(0, sample_ndims)
- chain_axis = math_ops.range(sample_ndims,
- sample_ndims + independent_chain_ndims)
- sample_and_chain_axis = math_ops.range(
- 0, sample_ndims + independent_chain_ndims)
-
- n = _axis_size(state, sample_axis)
- m = _axis_size(state, chain_axis)
-
- # In the language of [2],
- # B / n is the between chain variance, the variance of the chain means.
- # W is the within sequence variance, the mean of the chain variances.
- b_div_n = _reduce_variance(
- math_ops.reduce_mean(state, sample_axis, keepdims=True),
- sample_and_chain_axis,
- biased=False)
- w = math_ops.reduce_mean(
- _reduce_variance(state, sample_axis, keepdims=True, biased=True),
- sample_and_chain_axis)
-
- # sigma^2_+ is an estimate of the true variance, which would be unbiased if
- # each chain were drawn from the target; cf. the "law of total variance."
- sigma_2_plus = w + b_div_n
-
- return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
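
For intuition, the between/within decomposition above collapses to a few lines
of NumPy in the simplest case of one sample axis, one chain axis, and scalar
state. This is a sketch under those assumptions, not the implementation being
deleted:

```python
import numpy as np

def potential_scale_reduction_1d(state):
  """R-hat for `state` of shape (n_samples, n_chains)."""
  n, m = state.shape
  # B / n: between-chain variance, the (unbiased) variance of chain means.
  b_div_n = state.mean(axis=0).var(ddof=1)
  # W: within-chain variance, the mean of the (biased) per-chain variances.
  w = state.var(axis=0, ddof=0).mean()
  # sigma^2_+ = W + B/n estimates the true variance (law of total variance).
  sigma_2_plus = w + b_div_n
  return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)

rng = np.random.RandomState(0)
same_target = rng.randn(1000, 4)                          # four well-mixed chains
shifted = rng.randn(1000, 4) + np.array([0., 1., 2., 3.])  # four offset chains
print(potential_scale_reduction_1d(same_target))  # close to 1
print(potential_scale_reduction_1d(shifted))      # well above 1.2
```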
-
-
-# TODO(b/72873233) Move some variant of this to sample_stats.
-def _reduce_variance(x, axis=None, biased=True, keepdims=False):
- with ops.name_scope("reduce_variance"):
- x = ops.convert_to_tensor(x, name="x")
- mean = math_ops.reduce_mean(x, axis=axis, keepdims=True)
- biased_var = math_ops.reduce_mean(
- math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims)
- if biased:
- return biased_var
- n = _axis_size(x, axis)
- return (n / (n - 1.)) * biased_var
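
The `biased` flag mirrors NumPy's `ddof` convention (`biased=True` corresponds
to `ddof=0`, `biased=False` to `ddof=1`), which is what the deleted
`check_versus_numpy` test exercises. A quick illustrative check, assuming
nothing beyond NumPy:

```python
import numpy as np

x = np.random.RandomState(0).randn(2, 3, 4)

# Biased variance along axis 1: mean squared deviation from the axis mean.
biased = ((x - x.mean(axis=1, keepdims=True)) ** 2).mean(axis=1)
# Unbiased version: rescale by n / (n - 1), as in _reduce_variance.
n = x.shape[1]
unbiased = (n / (n - 1.)) * biased

assert np.allclose(biased, np.var(x, axis=1, ddof=0))
assert np.allclose(unbiased, np.var(x, axis=1, ddof=1))
```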
-
-
-def _axis_size(x, axis=None):
- """Get number of elements of `x` in `axis`, as type `x.dtype`."""
- if axis is None:
- return math_ops.cast(array_ops.size(x), x.dtype)
- return math_ops.cast(
- math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype)
-
-
-def _is_list_like(x):
- """Helper which returns `True` if input is `list`-like."""
- return isinstance(x, (tuple, list))
-
-
-def _broadcast_maybelist_arg(states, secondary_arg, name):
- """Broadcast a listable secondary_arg to that of states."""
- if _is_list_like(secondary_arg):
- if len(secondary_arg) != len(states):
- raise ValueError("Argument `{}` was a list of different length ({}) than "
- "`states` ({})".format(name, len(secondary_arg), len(states)))
- else:
- secondary_arg = [secondary_arg] * len(states)
-
- return secondary_arg