diff options
author | 2017-10-18 10:49:20 -0700 | |
---|---|---|
committer | 2017-10-18 10:53:13 -0700 | |
commit | c9d3377fb4f973e1592ebc71862e02dacf5f3a4f (patch) | |
tree | 96029aeade622183af2195e00cc6dc1d90941cea | |
parent | 3c31886537a8b5fb5ab62b4b925f8ef044960ca3 (diff) |
Make `tf.contrib.distributions` quadrature family parameterized by
`quadrature_grid_and_prob` vs `quadrature_degree`. Enables support of
quadrature methods other than Gauss-Hermite.
PiperOrigin-RevId: 172622919
3 files changed, 79 insertions, 50 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 7cb46bb236..3ded4159d8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.python.platform import test @@ -32,7 +34,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=-2., scale=1.1, - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( sess, pln, rtol=0.1) @@ -42,7 +45,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=0., scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( sess, pln, rtol=0.02) @@ -52,7 +56,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( sess, pln, rtol=0.1, atol=0.01) @@ -62,7 +67,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( sess, pln, rtol=0.1, atol=0.01) @@ -72,7 +78,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( sess, pln, rtol=0.1, atol=0.08) @@ -82,7 +89,8 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( sess, pln, rtol=0.1, atol=0.01) diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 65ee3a16d6..80d4e2dc5e 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -93,7 +93,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): : d=0, ..., deg-1 } ``` - where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( + where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) and `prob = w / sqrt(pi)`. @@ -106,14 +106,15 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): pln = ds.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) """ def __init__(self, loc, scale, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): @@ -124,8 +125,9 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_polynomial_degree: Python `int`-like scalar. - Default value: 8. + quadrature_grid_and_probs: Python pair of `list`-like objects representing + the sample points and the corresponding (possibly normalized) weight. + When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -138,6 +140,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): Raises: TypeError: if `loc.dtype != scale[0].dtype`. + ValueError: if `quadrature_grid_and_probs is not None` and + `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): @@ -153,18 +157,21 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( loc.dtype.name, scale.dtype.name)) - self._degree = quadrature_polynomial_degree - - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) - - # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is - # more numerically stable. - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + if quadrature_grid_and_probs is None: + grid, probs = np.polynomial.hermite.hermgauss(deg=8) + else: + grid, probs = tuple(quadrature_grid_and_probs) + if len(grid) != len(probs): + raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " + "same-length list-like objects") + grid = grid.astype(dtype.as_numpy_dtype) + probs = probs.astype(dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1) + self._quadrature_grid = grid + self._quadrature_probs = probs self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=np.log(probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -210,9 +217,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( @@ -242,10 +254,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): [batch_size])), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) - # Stride `quadrature_polynomial_degree` for `batch_size` number of times. + # Stride `quadrature_degree` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * len(self.quadrature_probs), + delta=len(self.quadrature_probs), dtype=ids.dtype) ids += offset rate = array_ops.gather( diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 438d628da4..33dad811a9 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -141,7 +141,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): and, ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_polynomial_degree) + grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree) prob[k] = weight[k] / sqrt(pi) lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) ``` @@ -219,7 +219,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): distribution, loc=None, scale=None, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): @@ -248,7 +248,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_polynomial_degree: Python `int`-like scalar. + quadrature_grid_and_probs: Python pair of `list`-like objects representing + the sample points and the corresponding (possibly normalized) weight. + When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -262,7 +264,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): Raises: ValueError: if `not scale or len(scale) < 2`. ValueError: if `len(loc) != len(scale)` - ValueError: if `quadrature_polynomial_degree < 1`. + ValueError: if `quadrature_grid_and_probs is not None` and + `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` ValueError: if `validate_args` and any not scale.is_positive_definite. TypeError: if any scale.dtype != scale[0].dtype. TypeError: if any loc.dtype != scale[0].dtype. @@ -307,12 +310,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] - if quadrature_polynomial_degree < 1: - raise ValueError("quadrature_polynomial_degree={} " - "is not at least 1".format( - quadrature_polynomial_degree)) - self._degree = quadrature_polynomial_degree - # TODO(jvdillon): Remove once we support k-mixtures. # We make this assertion here because otherwise `grid` would need to be a # vector not a scalar. @@ -320,17 +317,24 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) + if quadrature_grid_and_probs is None: + grid, probs = np.polynomial.hermite.hermgauss(deg=8) + else: + grid, probs = tuple(quadrature_grid_and_probs) + if len(grid) != len(probs): + raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " + "same-length list-like objects") grid = grid.astype(dtype.as_numpy_dtype) - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + probs = probs.astype(dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1) + self._quadrature_grid = grid + self._quadrature_probs = probs # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to # `math_ops.log(self.mixture_distribution.probs)`. self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=np.log(probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -357,10 +361,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(quadrature_polynomial_degree, + interpolate_loc(len(self._quadrature_grid), self._interpolate_weight, loc), - interpolate_scale(quadrature_polynomial_degree, + interpolate_scale(len(self._quadrature_grid), self._interpolate_weight, scale)))] @@ -416,9 +420,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._interpolated_affine @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return self._batch_shape_ @@ -454,10 +463,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - # Stride `self._degree` for `batch_size` number of times. + # Stride `quadrature_degree` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * len(self.quadrature_probs), + delta=len(self.quadrature_probs), dtype=ids.dtype) weight = array_ops.gather( |