author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-11 12:27:40 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-11-11 12:31:45 -0800
commit    743c7b17eeda2b4b13e5524168a096e871426a27 (patch)
tree      5d298f1dc1fe7e76296b5d6831a84bd2dc13fc7e
parent    692ee62a63a54d9fe02b3ef6e4e62de490046719 (diff)
Add support for Halton sequence for use with Monte Carlo estimation of integrals.
PiperOrigin-RevId: 175412461
-rw-r--r--  tensorflow/contrib/bayesflow/BUILD                                         19
-rw-r--r--  tensorflow/contrib/bayesflow/__init__.py                                    6
-rw-r--r--  tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py  131
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/halton_sequence.py                 33
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py           264
5 files changed, 451 insertions, 2 deletions
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index b024f158cd..f92b57869e 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -121,6 +121,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "halton_sequence_test",
+ size = "small",
+ srcs = ["python/kernel_tests/halton_sequence_test.py"],
+ additional_deps = [
+ ":bayesflow_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
name = "hmc_test",
size = "medium",
srcs = ["python/kernel_tests/hmc_test.py"],
diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py
index b98bc36954..beaf6f1854 100644
--- a/tensorflow/contrib/bayesflow/__init__.py
+++ b/tensorflow/contrib/bayesflow/__init__.py
@@ -23,6 +23,7 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence
from tensorflow.contrib.bayesflow.python.ops import custom_grad
+from tensorflow.contrib.bayesflow.python.ops import halton_sequence
from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
@@ -32,7 +33,8 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy',
- 'metropolis_hastings', 'monte_carlo', 'hmc', 'special_math',
- 'stochastic_variables', 'variational_inference']
+ 'metropolis_hastings', 'monte_carlo', 'halton_sequence',
+ 'hmc', 'special_math', 'stochastic_variables',
+ 'variational_inference']
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py
new file mode 100644
index 0000000000..0a85862abf
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for halton_sequence.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton
+from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.platform import test
+
+
+mc = monte_carlo_lib
+
+
+class HaltonSequenceTest(test.TestCase):
+
+ def test_known_values_small_bases(self):
+ with self.test_session():
+      # The first five elements of the Halton sequence with bases 2 and 3.
+ expected = np.array(((1. / 2, 1. / 3),
+ (1. / 4, 2. / 3),
+ (3. / 4, 1. / 9),
+ (1. / 8, 4. / 9),
+ (5. / 8, 7. / 9)), dtype=np.float32)
+ sample = halton.sample(2, num_samples=5)
+ self.assertAllClose(expected, sample.eval(), rtol=1e-6)
+
+ def test_sample_indices(self):
+ with self.test_session():
+ dim = 5
+ indices = math_ops.range(10, dtype=dtypes.int32)
+ sample_direct = halton.sample(dim, num_samples=10)
+ sample_from_indices = halton.sample(dim, sample_indices=indices)
+ self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(),
+ rtol=1e-6)
+
+ def test_dtypes_works_correctly(self):
+ with self.test_session():
+ dim = 3
+ sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32)
+ sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64)
+ self.assertEqual(sample_float32.eval().dtype, np.float32)
+ self.assertEqual(sample_float64.eval().dtype, np.float64)
+
+ def test_normal_integral_mean_and_var_correctly_estimated(self):
+    n = 1000
+    # This test is almost identical to the similarly named test in
+    # monte_carlo_test.py. The only difference is that we use Halton samples
+    # instead of pseudo-random samples to evaluate the expectations. MC with
+    # pseudo-random numbers converges at a rate of 1 / sqrt(N) (N = number of
+    # samples), while for QMC in low dimensions the expected convergence rate
+    # is ~ 1 / N. Hence 1e3 samples should suffice here, compared to the 1e6
+    # samples used in the pseudo-random Monte Carlo test.
+ with self.test_session():
+ mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64)
+ mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64)
+ sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64)
+ sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64)
+ p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
+ q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
+
+ cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64)
+ q_sample = q.quantile(cdf_sample)
+
+ # Compute E_p[X].
+ e_x = mc.expectation_importance_sampler(
+ f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample,
+ seed=42)
+
+ # Compute E_p[X^2].
+ e_x2 = mc.expectation_importance_sampler(
+ f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample,
+ seed=42)
+
+ stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))
+ # Keep the tolerance levels the same as in monte_carlo_test.py.
+ self.assertEqual(p.batch_shape, e_x.get_shape())
+ self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
+ self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02)
+
+ def test_docstring_example(self):
+ # Produce the first 1000 members of the Halton sequence in 3 dimensions.
+ num_samples = 1000
+ dim = 3
+ with self.test_session():
+ sample = halton.sample(dim, num_samples=num_samples)
+
+ # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional
+ # hypercube.
+ powers = math_ops.range(1.0, limit=dim + 1)
+ integral = math_ops.reduce_mean(
+ math_ops.reduce_prod(sample ** powers, axis=-1))
+ true_value = 1.0 / math_ops.reduce_prod(powers + 1.0)
+
+ # Produces a relative absolute error of 1.7%.
+ self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02)
+
+ # Now skip the first 1000 samples and recompute the integral with the next
+ # thousand samples. The sample_indices argument can be used to do this.
+
+ sample_indices = math_ops.range(start=1000, limit=1000 + num_samples,
+ dtype=dtypes.int32)
+ sample_leaped = halton.sample(dim, sample_indices=sample_indices)
+
+ integral_leaped = math_ops.reduce_mean(
+ math_ops.reduce_prod(sample_leaped ** powers, axis=-1))
+ self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001)
+
+
+if __name__ == '__main__':
+ test.main()
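The comment in `test_normal_integral_mean_and_var_correctly_estimated` above relies on the standard error rates: plain Monte Carlo converges at roughly 1 / sqrt(N), while low-dimensional quasi-Monte Carlo converges at roughly 1 / N. The standalone NumPy sketch below (not part of the patch; `halton_1d` and `halton` are illustrative helper names) makes that comparison concrete on the docstring's integrand. Exact error values will vary, but the Halton error should shrink markedly faster.

```python
import numpy as np


def halton_1d(index, base):
  """Radical inverse of the positive integer `index` in a prime `base`."""
  result, f = 0.0, 1.0
  while index > 0:
    f /= base
    result += f * (index % base)
    index //= base
  return result


def halton(num_samples, bases=(2, 3, 5)):
  """First `num_samples` points of the Halton sequence over `bases`."""
  return np.array([[halton_1d(i, b) for b in bases]
                   for i in range(1, num_samples + 1)])


# Integrand x_1 * x_2^2 * x_3^3 over the unit cube; exact value is
# 1/2 * 1/3 * 1/4 = 1/24.
powers = np.array([1.0, 2.0, 3.0])
true_value = 1.0 / np.prod(powers + 1.0)
rng = np.random.RandomState(42)
for n in (1000, 10000):
  qmc_est = np.mean(np.prod(halton(n) ** powers, axis=-1))
  mc_est = np.mean(np.prod(rng.rand(n, 3) ** powers, axis=-1))
  print('n=%5d  QMC rel. error: %.5f  MC rel. error: %.5f'
        % (n, abs(qmc_est / true_value - 1), abs(mc_est / true_value - 1)))
```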
diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py
new file mode 100644
index 0000000000..49d747d538
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py
@@ -0,0 +1,33 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for low discrepancy Halton sequences.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.bayesflow.python.ops.halton_sequence_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'sample',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py
new file mode 100644
index 0000000000..8cabf18903
--- /dev/null
+++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py
@@ -0,0 +1,264 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Quasi Monte Carlo support: Halton sequence.
+
+@@sample
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+__all__ = [
+ 'sample',
+]
+
+
+# The maximum dimension we support. This is limited by the number of primes
+# in the _PRIMES array.
+_MAX_DIMENSION = 1000
+
+
+def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
+ r"""Returns a sample from the `m` dimensional Halton sequence.
+
+ Warning: The sequence elements take values only between 0 and 1. Care must be
+ taken to appropriately transform the domain of a function if it differs from
+ the unit cube before evaluating integrals using Halton samples. It is also
+ important to remember that quasi-random numbers are not a replacement for
+  pseudo-random numbers in every context. Quasi-random numbers are completely
+ deterministic and typically have significant negative autocorrelation (unless
+ randomized).
+
+  Computes the members of the low discrepancy Halton sequence in dimension
+  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube.
+  Currently, only dimensions up to 1000 are supported. The prime base for the
+  `k`-th axis is the `k`-th prime, starting from 2. For example, if `dim` = 3,
+  the bases will be [2, 3, 5] and the first element of the sequence will be
+  [0.5, 0.333, 0.2]. For a more complete description of the Halton sequence
+  see: https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy
+  sequences and their applications see:
+  https://en.wikipedia.org/wiki/Low-discrepancy_sequence.
+
+  The user must supply either `num_samples` or `sample_indices` but not both.
+  The former is the number of samples to produce starting from the first
+  element. If `sample_indices` is given instead, the specified elements of
+  the sequence are generated. For example, sample_indices=tf.range(10) is
+  equivalent to specifying num_samples=10.
+
+ Example Use:
+
+ ```python
+ bf = tf.contrib.bayesflow
+
+ # Produce the first 1000 members of the Halton sequence in 3 dimensions.
+ num_samples = 1000
+ dim = 3
+ sample = bf.halton_sequence.sample(dim, num_samples=num_samples)
+
+ # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional
+ # hypercube.
+ powers = tf.range(1.0, limit=dim + 1)
+ integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
+ true_value = 1.0 / tf.reduce_prod(powers + 1.0)
+ with tf.Session() as session:
+ values = session.run((integral, true_value))
+
+ # Produces a relative absolute error of 1.7%.
+ print ("Estimated: %f, True Value: %f" % values)
+
+ # Now skip the first 1000 samples and recompute the integral with the next
+ # thousand samples. The sample_indices argument can be used to do this.
+
+ sample_indices = tf.range(start=1000, limit=1000 + num_samples,
+ dtype=tf.int32)
+  sample_leaped = bf.halton_sequence.sample(dim,
+                                            sample_indices=sample_indices)
+
+ integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
+ axis=-1))
+ with tf.Session() as session:
+ values = session.run((integral_leaped, true_value))
+ # Now produces a relative absolute error of 0.05%.
+ print ("Leaped Estimated: %f, True Value: %f" % values)
+ ```
+
+ Args:
+    dim: Positive Python `int` representing each sample's `event_size`. Must
+      not be greater than 1000.
+    num_samples: (Optional) Positive Python `int`. The number of samples to
+      generate. Either this parameter or `sample_indices` must be specified,
+      but not both. If this parameter is None, then `sample_indices` must be
+      supplied and determines which elements to generate.
+    sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements
+      of the sequence to compute, specified by their position in the sequence.
+      The entries index into the Halton sequence starting with 0 and hence must
+      be whole numbers. For example, sample_indices=[0, 5, 6] will produce the
+      first, sixth and seventh elements of the sequence. If this parameter is
+      None, then the `num_samples` parameter must be specified, which gives the
+      number of desired samples starting from the first sample.
+ dtype: (Optional) The dtype of the sample. One of `float32` or `float64`.
+ Default is `float32`.
+ name: (Optional) Python `str` describing ops managed by this function. If
+ not supplied the name of this function is used.
+
+ Returns:
+    halton_elements: Elements of the Halton sequence. `Tensor` of the supplied
+      dtype with shape `[num_samples, dim]` if `num_samples` was specified, or
+      shape `[s, dim]` where `s` is the size of `sample_indices` if
+      `sample_indices` was specified.
+
+ Raises:
+    ValueError: if both or neither of `sample_indices` and `num_samples` were
+      specified, or if dimension `dim` is less than 1 or greater than 1000.
+ """
+ if dim < 1 or dim > _MAX_DIMENSION:
+ raise ValueError(
+ 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION,
+ dim))
+ if (num_samples is None) == (sample_indices is None):
+ raise ValueError('Either `num_samples` or `sample_indices` must be'
+ ' specified but not both.')
+
+ dtype = dtype or dtypes.float32
+ if not dtype.is_floating:
+ raise ValueError('dtype must be of `float`-type')
+
+ with ops.name_scope(name, 'sample', values=[sample_indices]):
+ # Here and in the following, the shape layout is as follows:
+ # [sample dimension, event dimension, coefficient dimension].
+    # The coefficient dimension is an intermediate axis which will hold the
+ # weights of the starting integer when expressed in the (prime) base for
+ # an event dimension.
+ indices = _get_indices(num_samples, sample_indices, dtype)
+ radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])
+
+ max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices),
+ radixes)
+
+ max_size = math_ops.reduce_max(max_sizes_by_axes)
+
+    # The powers of the radixes that we will need. Note that there is a bit
+    # of an excess here. Suppose we need the place value coefficients of 7
+    # in bases 2 and 3. In base 2 we need 3 digits, but we only need 2 digits
+    # in base 3. However, we can only create rectangular tensors, so we store
+    # both expansions in a [2, 3] tensor. This leads to the problem that we
+    # might end up attempting to raise large numbers to large powers. For
+    # example, the base 2 expansion of 1024 has 11 digits. If we were in 10
+    # dimensions, then for the 10th prime (29) we would end up computing 29^10
+    # even though we don't need it. We avoid this by setting the exponents for
+    # each axis to 0 beyond the maximum value needed for that dimension.
+ exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1])
+ weight_mask = exponents_by_axes > max_sizes_by_axes
+ capped_exponents = array_ops.where(
+ weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes)
+ weights = radixes ** capped_exponents
+ coeffs = math_ops.floor_div(indices, weights)
+ coeffs *= 1 - math_ops.cast(weight_mask, dtype)
+ coeffs = (coeffs % radixes) / radixes
+ return math_ops.reduce_sum(coeffs / weights, axis=-1)
+
+
+def _get_indices(n, sample_indices, dtype, name=None):
+ """Generates starting points for the Halton sequence procedure.
+
+  The k'th element of the sequence is generated starting from a positive
+  integer which must be distinct for each `k`. It is conventional to choose
+  the starting point as `k` itself (or `k+1` if `k` is zero-based). This
+  function generates the starting integers for the required elements and
+  reshapes the result for later use.
+
+ Args:
+ n: Positive `int`. The number of samples to generate. If this
+ parameter is supplied, then `sample_indices` should be None.
+ sample_indices: `Tensor` of dtype int32 and rank 1. The entries
+ index into the Halton sequence starting with 0 and hence, must be whole
+ numbers. For example, sample_indices=[0, 5, 6] will produce the first,
+ sixth and seventh elements of the sequence. If this parameter is not None
+ then `n` must be None.
+    dtype: The dtype of the sample. One of `float32` or `float64`.
+ name: Python `str` name which describes ops created by this function.
+
+ Returns:
+    indices: `Tensor` of dtype `dtype` and shape `[n, 1, 1]`.
+ """
+ with ops.name_scope(name, 'get_indices', [n, sample_indices]):
+ if sample_indices is None:
+ sample_indices = math_ops.range(n, dtype=dtype)
+ else:
+ sample_indices = math_ops.cast(sample_indices, dtype)
+
+ # Shift the indices so they are 1 based.
+ indices = sample_indices + 1
+
+ # Reshape to make space for the event dimension and the place value
+ # coefficients.
+ return array_ops.reshape(indices, [-1, 1, 1])
+
+
+def _base_expansion_size(num, bases):
+ """Computes the number of terms in the place value expansion.
+
+  Let num = a0 + a1 b + a2 b^2 + ... + ak b^k be the place value expansion of
+  `num` in base b (ak != 0). This function computes and returns `k + 1`, the
+  number of digits, for each base `b` specified in `bases`.
+
+  This can be inferred from the base `b` logarithm of `num` as follows:
+  $$k + 1 = Floor(log_b(num)) + 1 = Floor(log(num) / log(b)) + 1$$
+
+  For example, num = 7 is written as 111 in base 2, so its expansion has
+  Floor(log_2(7)) + 1 = 3 terms.
+
+ Args:
+ num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to
+ compute the base expansion size of.
+ bases: `Tensor` of the same dtype as num. The bases to compute the size
+ against.
+
+ Returns:
+    Tensor of same dtype and shape as `bases` containing the number of digits
+    of `num` when written in the corresponding base.
+ """
+ return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1
+
+
+def _primes_less_than(n):
+ # Based on
+ # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
+ """Returns sorted array of primes such that `2 <= prime < n`."""
+ small_primes = np.array((2, 3, 5))
+ if n <= 6:
+ return small_primes[small_primes < n]
+ sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool)
+ sieve[0] = False
+ m = int(n ** 0.5) // 3 + 1
+ for i in range(m):
+ if not sieve[i]:
+ continue
+ k = 3 * i + 1 | 1
+ sieve[k ** 2 // 3::2 * k] = False
+ sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False
+ return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1]
+
+# The 1000th prime is 7919, so _PRIMES contains exactly _MAX_DIMENSION primes.
+_PRIMES = _primes_less_than(7919 + 1)
+
+assert len(_PRIMES) == _MAX_DIMENSION
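The vectorized computation in `sample` above can be hard to follow because three axes (sample, event, coefficient) are packed into a single tensor. The following NumPy sketch walks through the same steps under the patch's shape conventions; the variable names are illustrative, and it assumes the 1-based indices produced by `_get_indices`.

```python
import numpy as np

dim, num_samples = 2, 7
radixes = np.array([2.0, 3.0]).reshape(dim, 1)            # primes, [dim, 1]
indices = np.arange(1, num_samples + 1,
                    dtype=np.float64).reshape(-1, 1, 1)   # [samples, 1, 1]

# Digits of the largest index in each base: Floor(log_b(max)) + 1.
max_sizes = np.floor(np.log(indices.max()) / np.log(radixes)) + 1  # [dim, 1]
max_size = int(max_sizes.max())

# Exponent grid of shape [dim, max_size]; exponents beyond each base's needs
# are zeroed out so no prime is raised to an unnecessarily large power.
exponents = np.tile(np.arange(max_size), (dim, 1))
mask = exponents > max_sizes
weights = radixes ** np.where(mask, 0, exponents)         # b^j per axis

coeffs = np.floor_divide(indices, weights)
coeffs *= 1 - mask                                        # drop masked digits
coeffs = (coeffs % radixes) / radixes                     # digit a_j over b
halton_points = np.sum(coeffs / weights, axis=-1)         # sum_j a_j / b^(j+1)

print(halton_points[:5])
# ~ [[0.5, 0.3333], [0.25, 0.6667], [0.75, 0.1111],
#    [0.125, 0.4444], [0.625, 0.7778]]
```

The printed rows match the expected values in `test_known_values_small_bases`. In this tiny example the exponent mask is all False; it only bites when one base needs far fewer digits than another, which is exactly the overflow case the comment in `sample` describes.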