aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2017-10-18 10:49:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-18 10:53:13 -0700
commitc9d3377fb4f973e1592ebc71862e02dacf5f3a4f (patch)
tree96029aeade622183af2195e00cc6dc1d90941cea
parent3c31886537a8b5fb5ab62b4b925f8ef044960ca3 (diff)
Make `tf.contrib.distributions` quadrature family parameterized by
`quadrature_grid_and_prob` vs `quadrature_degree`. Enables support of quadrature methods other than Gauss-Hermite. PiperOrigin-RevId: 172622919
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py20
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py54
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py55
3 files changed, 79 insertions, 50 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
index 7cb46bb236..3ded4159d8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.distributions.python.ops import poisson_lognormal
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.platform import test
@@ -32,7 +34,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=-2.,
scale=1.1,
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1)
@@ -42,7 +45,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=0.,
scale=1.,
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.02)
@@ -52,7 +56,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.01)
@@ -62,7 +67,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
@@ -72,7 +78,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.08)
@@ -82,7 +89,8 @@ class PoissonLogNormalQuadratureCompoundTest(
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 65ee3a16d6..80d4e2dc5e 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -93,7 +93,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
: d=0, ..., deg-1 }
```
- where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
+ where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html)
and `prob = w / sqrt(pi)`.
@@ -106,14 +106,15 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
pln = ds.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
- quadrature_polynomial_degree=10,
+ quadrature_grid_and_probs=(
+ np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
"""
def __init__(self,
loc,
scale,
- quadrature_polynomial_degree=8,
+ quadrature_grid_and_probs=None,
validate_args=False,
allow_nan_stats=True,
name="PoissonLogNormalQuadratureCompound"):
@@ -124,8 +125,9 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
the LogNormal prior.
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
the LogNormal prior.
- quadrature_polynomial_degree: Python `int`-like scalar.
- Default value: 8.
+ quadrature_grid_and_probs: Python pair of `list`-like objects representing
+ the sample points and the corresponding (possibly normalized) weight.
+ When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -138,6 +140,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
Raises:
TypeError: if `loc.dtype != scale[0].dtype`.
+ ValueError: if `quadrature_grid_and_probs is not None` and
+ `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
@@ -153,18 +157,21 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
loc.dtype.name, scale.dtype.name))
- self._degree = quadrature_polynomial_degree
-
- grid, prob = np.polynomial.hermite.hermgauss(
- deg=quadrature_polynomial_degree)
-
- # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is
- # more numerically stable.
- prob = prob.astype(dtype.as_numpy_dtype)
- prob /= np.linalg.norm(prob, ord=1)
+ if quadrature_grid_and_probs is None:
+ grid, probs = np.polynomial.hermite.hermgauss(deg=8)
+ else:
+ grid, probs = tuple(quadrature_grid_and_probs)
+ if len(grid) != len(probs):
+ raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
+ "same-length list-like objects")
+ grid = grid.astype(dtype.as_numpy_dtype)
+ probs = probs.astype(dtype.as_numpy_dtype)
+ probs /= np.linalg.norm(probs, ord=1)
+ self._quadrature_grid = grid
+ self._quadrature_probs = probs
self._mixture_distribution = categorical_lib.Categorical(
- logits=np.log(prob),
+ logits=np.log(probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
@@ -210,9 +217,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
return self._scale
@property
- def quadrature_polynomial_degree(self):
- """Polynomial largest exponent used for Gauss-Hermite quadrature."""
- return self._degree
+ def quadrature_grid(self):
+ """Quadrature grid points."""
+ return self._quadrature_grid
+
+ @property
+ def quadrature_probs(self):
+ """Quadrature normalized weights."""
+ return self._quadrature_probs
def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
@@ -242,10 +254,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
[batch_size])),
seed=distribution_util.gen_new_seed(
seed, "poisson_lognormal_quadrature_compound"))
- # Stride `quadrature_polynomial_degree` for `batch_size` number of times.
+ # Stride `quadrature_degree` for `batch_size` number of times.
offset = math_ops.range(start=0,
- limit=batch_size * self._degree,
- delta=self._degree,
+ limit=batch_size * len(self.quadrature_probs),
+ delta=len(self.quadrature_probs),
dtype=ids.dtype)
ids += offset
rate = array_ops.gather(
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 438d628da4..33dad811a9 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -141,7 +141,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
and,
```none
- grid, weight = np.polynomial.hermite.hermgauss(quadrature_polynomial_degree)
+ grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree)
prob[k] = weight[k] / sqrt(pi)
lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
```
@@ -219,7 +219,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
distribution,
loc=None,
scale=None,
- quadrature_polynomial_degree=8,
+ quadrature_grid_and_probs=None,
validate_args=False,
allow_nan_stats=True,
name="VectorDiffeomixture"):
@@ -248,7 +248,9 @@ class VectorDiffeomixture(distribution_lib.Distribution):
`k`-th element represents the `scale` used for the `k`-th affine
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
- quadrature_polynomial_degree: Python `int`-like scalar.
+ quadrature_grid_and_probs: Python pair of `list`-like objects representing
+ the sample points and the corresponding (possibly normalized) weight.
+ When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -262,7 +264,8 @@ class VectorDiffeomixture(distribution_lib.Distribution):
Raises:
ValueError: if `not scale or len(scale) < 2`.
ValueError: if `len(loc) != len(scale)`
- ValueError: if `quadrature_polynomial_degree < 1`.
+ ValueError: if `quadrature_grid_and_probs is not None` and
+ `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
ValueError: if `validate_args` and any not scale.is_positive_definite.
TypeError: if any scale.dtype != scale[0].dtype.
TypeError: if any loc.dtype != scale[0].dtype.
@@ -307,12 +310,6 @@ class VectorDiffeomixture(distribution_lib.Distribution):
name="endpoint_affine_{}".format(k))
for k, (loc_, scale_) in enumerate(zip(loc, scale))]
- if quadrature_polynomial_degree < 1:
- raise ValueError("quadrature_polynomial_degree={} "
- "is not at least 1".format(
- quadrature_polynomial_degree))
- self._degree = quadrature_polynomial_degree
-
# TODO(jvdillon): Remove once we support k-mixtures.
# We make this assertion here because otherwise `grid` would need to be a
# vector not a scalar.
@@ -320,17 +317,24 @@ class VectorDiffeomixture(distribution_lib.Distribution):
raise NotImplementedError("Currently only bimixtures are supported; "
"len(scale)={} is not 2.".format(len(scale)))
- grid, prob = np.polynomial.hermite.hermgauss(
- deg=quadrature_polynomial_degree)
+ if quadrature_grid_and_probs is None:
+ grid, probs = np.polynomial.hermite.hermgauss(deg=8)
+ else:
+ grid, probs = tuple(quadrature_grid_and_probs)
+ if len(grid) != len(probs):
+ raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
+ "same-length list-like objects")
grid = grid.astype(dtype.as_numpy_dtype)
- prob = prob.astype(dtype.as_numpy_dtype)
- prob /= np.linalg.norm(prob, ord=1)
+ probs = probs.astype(dtype.as_numpy_dtype)
+ probs /= np.linalg.norm(probs, ord=1)
+ self._quadrature_grid = grid
+ self._quadrature_probs = probs
# Note: by creating the logits as `log(prob)` we ensure that
# `self.mixture_distribution.logits` is equivalent to
# `math_ops.log(self.mixture_distribution.probs)`.
self._mixture_distribution = categorical_lib.Categorical(
- logits=np.log(prob),
+ logits=np.log(probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
@@ -357,10 +361,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
validate_args=validate_args,
name="interpolated_affine_{}".format(k))
for k, (loc_, scale_) in enumerate(zip(
- interpolate_loc(quadrature_polynomial_degree,
+ interpolate_loc(len(self._quadrature_grid),
self._interpolate_weight,
loc),
- interpolate_scale(quadrature_polynomial_degree,
+ interpolate_scale(len(self._quadrature_grid),
self._interpolate_weight,
scale)))]
@@ -416,9 +420,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
return self._interpolated_affine
@property
- def quadrature_polynomial_degree(self):
- """Polynomial largest exponent used for Gauss-Hermite quadrature."""
- return self._degree
+ def quadrature_grid(self):
+ """Quadrature grid points."""
+ return self._quadrature_grid
+
+ @property
+ def quadrature_probs(self):
+ """Quadrature normalized weights."""
+ return self._quadrature_probs
def _batch_shape_tensor(self):
return self._batch_shape_
@@ -454,10 +463,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
seed=distribution_util.gen_new_seed(
seed, "vector_diffeomixture"))
- # Stride `self._degree` for `batch_size` number of times.
+ # Stride `quadrature_degree` for `batch_size` number of times.
offset = math_ops.range(start=0,
- limit=batch_size * self._degree,
- delta=self._degree,
+ limit=batch_size * len(self.quadrature_probs),
+ delta=len(self.quadrature_probs),
dtype=ids.dtype)
weight = array_ops.gather(