author    Ian Langmore <langmore@google.com>  2018-01-24 09:49:47 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-24 09:54:08 -0800
commit    7b62a71e2d46c148df7d5704972f4592bc5e0f1b (patch)
tree      841cd6d0bef3617bf68523907c8f79b396ef093a
parent    ffdae0a35785337bd10fed289d3998ba0e7c014b (diff)
* BUGFIX: See code associated with `scale_identity_multiplier`.
* BUGFIX: See code associated with `weight` inside `sample_n`.
* Use parameterization `temperature` rather than `mix_scale`.
* Simplify documentation and link to arXiv paper for details.
* Document that we allow `temperature` to have any shape broadcastable with
  `mix_loc`. This is a moot point to some degree since we require K = 2 now.
* Add some tests.

PiperOrigin-RevId: 183098275
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py  148
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py  300
2 files changed, 258 insertions(+), 190 deletions(-)
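Reading the constructor change in the second file below, the new `temperature` argument is fed to the quadrature scheme as `normal_loc = mix_loc / temperature` and `normal_scale = 1 / temperature`, so `temperature` roughly plays the role of `1 / mix_scale`. The following is a minimal NumPy sketch (not the TensorFlow implementation) of the Gauss-Hermite variant of that scheme, written only to illustrate the reparameterization; the `softmax` helper and the front-padding convention here are assumptions modeled on the docstring in the diff.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def softmaxnormal_gauss_hermite(mix_loc, temperature, quadrature_size):
    # Temperature parameterization: X ~ Normal(mix_loc / temperature, 1 / temperature).
    normal_loc = np.asarray(mix_loc, dtype=np.float64) / temperature       # [..., K-1]
    normal_scale = np.asarray(1. / temperature, dtype=np.float64)
    nodes, probs = np.polynomial.hermite.hermgauss(quadrature_size)
    probs = probs / np.linalg.norm(probs, ord=1, keepdims=True)            # weights sum to 1
    x = (normal_loc[..., np.newaxis]
         + np.sqrt(2.) * normal_scale[..., np.newaxis] * nodes)            # [..., K-1, Q]
    # Pad a zero row in front, then softmax over components to land on the K-simplex.
    x = np.concatenate([np.zeros_like(x[..., :1, :]), x], axis=-2)         # [..., K, Q]
    return softmax(x, axis=-2), probs

# Smaller temperature pushes the grid points toward the simplex corners (0, 1),
# i.e. the diffeomixture behaves more like a discrete two-component mixture.
z_hot, w = softmaxnormal_gauss_hermite(mix_loc=[0.], temperature=2.0, quadrature_size=8)
z_cold, _ = softmaxnormal_gauss_hermite(mix_loc=[0.], temperature=0.2, quadrature_size=8)
```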
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
index d292b04665..04f047aa0c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
@@ -27,6 +27,8 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
from tensorflow.python.platform import test
+rng = np.random.RandomState(0)
+
class VectorDiffeomixtureTest(
test_util.VectorDistributionTestHelpers, test.TestCase):
@@ -37,7 +39,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -66,7 +68,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(1., 1.5),
loc=[
None,
@@ -95,7 +97,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -122,12 +124,39 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.01)
+ def testSampleProbConsistentBroadcastMixTwoBatchDims(self):
+ dims = 4
+ loc_1 = rng.randn(2, 3, dims).astype(np.float32)
+
+ with self.test_session() as sess:
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32),
+ temperature=[1.],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ None,
+ loc_1,
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=[np.float32(1.1)],
+ is_positive_definite=True),
+ ] * 2,
+ validate_args=True)
+ # Ball centered at component0's mean.
+ self.run_test_sample_consistent_log_prob(
+ sess.run, vdm, radius=2., center=0., rtol=0.01)
+ # Larger ball centered at component1's mean.
+ self.run_test_sample_consistent_log_prob(
+ sess.run, vdm, radius=3., center=loc_1, rtol=0.02)
+
def testMeanCovarianceNoBatch(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[1 / 10.],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([-2.]),
@@ -147,12 +176,94 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_mean_covariance(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+ def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self):
+ # As temperature decreases, this should approach a mixture of normals, with
+ # components at -2, 2.
+ with self.test_session() as sess:
+ dims = 1
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=[0.],
+ temperature=[[2.], [1.], [0.2]],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ np.float32([-2.]),
+ np.float32([2.]),
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=np.float32(0.5),
+ is_positive_definite=True),
+ ] * 2, # Use the same scale for each component.
+ quadrature_size=8,
+ validate_args=True)
+
+ samps = vdm.sample(10000)
+ self.assertAllEqual((10000, 3, 1), samps.shape)
+ samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
+
+ # One characteristic of a discrete mixture (as opposed to a "smear") is
+ # that more weight is put near the component centers at -2, 2, and thus
+ # less weight is put near the origin.
+ prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0)
+ self.assertGreater(
+ prob_of_being_near_origin[0], prob_of_being_near_origin[1])
+ self.assertGreater(
+ prob_of_being_near_origin[1], prob_of_being_near_origin[2])
+
+ # Run this test as well, just because we can.
+ self.run_test_sample_consistent_mean_covariance(
+ sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
+ def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self):
+ with self.test_session() as sess:
+ dims = 1
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=[[-1.], [0.], [1.]],
+ temperature=[0.5],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ np.float32([-2.]),
+ np.float32([2.]),
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=np.float32(0.5),
+ is_positive_definite=True),
+ ] * 2, # Use the same scale for each component.
+ quadrature_size=8,
+ validate_args=True)
+
+ samps = vdm.sample(10000)
+ self.assertAllEqual((10000, 3, 1), samps.shape)
+ samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
+
+ # One characteristic of putting more weight on a component is that the
+ # mean is closer to that component's mean.
+ # Get the mean for each batch member, the names signify the value of
+ # concentration for that batch member.
+ mean_neg1, mean_0, mean_1 = samps_.mean(axis=0)
+
+ # Since concentration is the concentration for component 0,
+ # concentration = -1 ==> more weight on component 1, which has mean = 2
+ # concentration = 0 ==> equal weight
+ # concentration = 1 ==> more weight on component 0, which has mean = -2
+ self.assertLess(-2, mean_1)
+ self.assertLess(mean_1, mean_0)
+ self.assertLess(mean_0, mean_neg1)
+ self.assertLess(mean_neg1, 2)
+
+ # Run this test as well, just because we can.
+ self.run_test_sample_consistent_mean_covariance(
+ sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[0.1],
distribution=normal_lib.Normal(-1., 1.5),
loc=[
np.float32([-2.]),
@@ -177,7 +288,7 @@ class VectorDiffeomixtureTest(
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([[-2.]]),
@@ -205,7 +316,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
- mix_scale=[1.],
+ temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -229,29 +340,6 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.005)
- # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
- # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
- # tested that the quadrature approach well-approximates the integral.
- #
- # To that end, consider adding these tests:
- #
- # Test1: In the limit of high mix_scale, this approximates a discrete mixture,
- # and there are many discrete mixtures where we can explicitly compute
- # mean/var, etc... So test1 would choose one of those discrete mixtures and
- # show our mean/var/etc... is close to that.
- #
- # Test2: In the limit of low mix_scale, the a diffeomixture of Normal(-5, 1),
- # Normal(5, 1) should (I believe...must check) should look almost like
- # Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5),
- # and (ii) the first few moments should approximately match that of
- # Uniform(-5, 5)
- #
- # Test3: If mix_loc is symmetric, then for any mix_scale, our
- # quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have
- # mean zero, exactly.
-
- # TODO(jvdillon): Add more tests which verify broadcasting.
-
if __name__ == "__main__":
test.main()
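The new tests above assert, via sampling, that `mix_loc` shifts weight toward component 0 and that lower `temperature` sharpens the mixture. A rough, self-contained NumPy check of the first property, assuming the sigmoid/Gauss-Hermite form from the old docstring (the library's default is the quantile-based scheme), is sketched here; it is an illustration, not the test's method.

```python
import numpy as np

def mixture_mean(mix_loc, temperature, loc0, loc1, quadrature_size=8):
    nodes, w = np.polynomial.hermite.hermgauss(quadrature_size)
    w = w / w.sum()                                     # normalized quadrature weights
    # Weight on component 0 at each quadrature node (2-class softmax == sigmoid).
    logits = mix_loc / temperature + np.sqrt(2.) / temperature * nodes
    z0 = 1. / (1. + np.exp(-logits))
    return np.sum(w * (z0 * loc0 + (1. - z0) * loc1))

for mix_loc in (-1., 0., 1.):
    print(mix_loc, mixture_mean(mix_loc, temperature=0.5, loc0=-2., loc1=2.))
# Expect a strictly decreasing sequence: larger mix_loc puts more weight on
# component 0 (loc0 = -2), mirroring the assertLess chain in the test above.
```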
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 7ce8a83fd9..0c747f8e68 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -50,20 +50,25 @@ __all__ = [
def quadrature_scheme_softmaxnormal_gauss_hermite(
- loc, scale, quadrature_size,
+ normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.
+ A `SoftmaxNormal` random variable `Y` may be generated via
+
+ ```
+ Y = SoftmaxCentered(X),
+ X = Normal(normal_loc, normal_scale)
+ ```
+
Note: for a given `quadrature_size`, this method is generally less accurate
than `quadrature_scheme_softmaxnormal_quantiles`.
Args:
- loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `location` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `scale` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
+ normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+ The location parameter of the Normal used to construct the SoftmaxNormal.
+ normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+ The scale parameter of the Normal used to construct the SoftmaxNormal.
quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
@@ -80,24 +85,25 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
associated with each grid point.
"""
with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite",
- [loc, scale]):
- loc = ops.convert_to_tensor(loc, name="loc")
- dt = loc.dtype.base_dtype
- scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+ [normal_loc, normal_scale]):
+ normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+ dt = normal_loc.dtype.base_dtype
+ normal_scale = ops.convert_to_tensor(
+ normal_scale, dtype=dt, name="normal_scale")
- loc = maybe_check_quadrature_param(loc, "loc", validate_args)
- scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+ normal_scale = maybe_check_quadrature_param(
+ normal_scale, "normal_scale", validate_args)
grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
- grid = grid.astype(loc.dtype.as_numpy_dtype)
- probs = probs.astype(loc.dtype.as_numpy_dtype)
+ grid = grid.astype(dt.dtype.as_numpy_dtype)
+ probs = probs.astype(dt.dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
- probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype)
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dt)
grid = softmax(
-distribution_util.pad(
- (loc[..., array_ops.newaxis] +
- np.sqrt(2.) * scale[..., array_ops.newaxis] * grid),
+ (normal_loc[..., array_ops.newaxis] +
+ np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid),
axis=-2,
front=True),
axis=-2) # shape: [B, components, deg]
@@ -106,18 +112,23 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
def quadrature_scheme_softmaxnormal_quantiles(
- loc, scale, quadrature_size,
+ normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.
+ A `SoftmaxNormal` random variable `Y` may be generated via
+
+ ```
+ Y = SoftmaxCentered(X),
+ X = Normal(normal_loc, normal_scale)
+ ```
+
Args:
- loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `location` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `scale` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- quadrature_size: Python scalar `int` representing the number of quadrature
+ normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+ The location parameter of the Normal used to construct the SoftmaxNormal.
+ normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+ The scale parameter of the Normal used to construct the SoftmaxNormal.
+ quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
@@ -132,15 +143,17 @@ def quadrature_scheme_softmaxnormal_quantiles(
probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
associated with each grid point.
"""
- with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]):
- loc = ops.convert_to_tensor(loc, name="loc")
- dt = loc.dtype.base_dtype
- scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+ with ops.name_scope(name, "softmax_normal_grid_and_probs",
+ [normal_loc, normal_scale]):
+ normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+ dt = normal_loc.dtype.base_dtype
+ normal_scale = ops.convert_to_tensor(
+ normal_scale, dtype=dt, name="normal_scale")
- loc = maybe_check_quadrature_param(loc, "loc", validate_args)
- scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+ normal_scale = maybe_check_quadrature_param(
+ normal_scale, "normal_scale", validate_args)
- dist = normal_lib.Normal(loc=loc, scale=scale)
+ dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale)
def _get_batch_ndims():
"""Helper to get dist.batch_shape.ndims, statically if possible."""
@@ -195,114 +208,51 @@ def quadrature_scheme_softmaxnormal_quantiles(
class VectorDiffeomixture(distribution_lib.Distribution):
"""VectorDiffeomixture distribution.
- The VectorDiffeomixture is an approximation to a [compound distribution](
- https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
+ A vector diffeomixture (VDM) is a distribution parameterized by a convex
+ combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K`
+ `scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following
+ [compound distribution]
+ (https://en.wikipedia.org/wiki/Compound_probability_distribution)
```none
- p(x) = int_{X} q(x | v) p(v) dv
- = lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k],
- scale=sum_k^K lambda[k;i] scale[k])
- : i=0, ..., Q-1 }
+ p(x) = int p(x | z) p(z) dz,
+ where z is in the K-simplex, and
+ p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
```
- where `q(x | v)` is a vector version of the `distribution` argument and `p(v)`
- is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The
- vector-ization of `distribution` entails an affine transformation of iid
- samples from `distribution`. The `prob` term is from quadrature and
- `lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the
- `grid` points correspond to the `prob`s.
-
- In the non-approximation case, a draw from the mixture distribution (the
- "prior") represents the convex weights for different affine transformations.
- I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final
- sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@`
- denotes matrix multiplication. However, the non-approximate distribution does
- not have an analytical probability density function (pdf). Therefore the
- `VectorDiffeomixture` class implements an approximation based on
- [numerical quadrature](
- https://en.wikipedia.org/wiki/Numerical_integration) (default:
- [Gauss--Hermite quadrature](
- https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
- Note: although the `VectorDiffeomixture` is approximately the
- `SoftmaxNormal-Distribution` compound distribution, it is itself a valid
- distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
- are all mutually consistent.
-
- #### Intended Use
-
- This distribution is noteworthy because it implements a mixture of
- `Vector`-ized distributions yet has samples differentiable in the
- distribution's parameters (aka "reparameterized"). It has an analytical
- density function with `O(dKQ)` complexity. `d` is the vector dimensionality,
- `K` is the number of components, and `Q` is the number of quadrature points.
- These properties make it well-suited for Bayesian Variational Inference, i.e.,
- as a surrogate family for the posterior.
-
- For large values of `mix_scale`, the `VectorDistribution` behaves increasingly
- like a discrete mixture. (In most cases this limit is only achievable by also
- increasing the quadrature polynomial degree, `Q`.)
-
- The term `Vector` is consistent with similar named Tensorflow `Distribution`s.
- For more details, see the "About `Vector` distributions in Tensorflow."
- section.
-
- The term `Diffeomixture` is a portmanteau of
- [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound
- mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For
- more details, see the "About `Diffeomixture`s and reparametrization.`"
- section.
-
- #### Mathematical Details
-
- The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
- [compound distribution](
- https://en.wikipedia.org/wiki/Compound_probability_distribution).
- Using variable-substitution and [numerical quadrature](
- https://en.wikipedia.org/wiki/Numerical_integration) (default:
- [Gauss--Hermite quadrature](
- https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
- redefine the distribution to be a parameter-less convex combination of `K`
- different affine combinations of a `d` iid samples from `distribution`.
-
- That is, defined over `R**d` this distribution is parameterized by a
- (batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of
- (a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale`
- `LinearOperator`s each operating on a (batch of) length-`d` vector space.
- Finally, a `distribution` parameter specifies the underlying base distribution
- which is "lifted" to become multivariate ("lifting" is the same concept as in
- `TransformedDistribution`).
-
- The probability density function (pdf) is,
+ The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme
+ adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}`
+ and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen
+ such that
- ```none
- pdf(y; mix_loc, mix_scale, loc, scale, phi)
- = sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i]))
- : i=0, ..., Q-1 }
- ```
+ ```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)```
- where, `phi` is the base distribution pdf, and,
+ as `N --> infinity`.
- ```none
- f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i])
- interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 }
- interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 }
- ```
+ Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from
+ `q_N` exactly. It is important to note that the VDM is *defined* as `q_N`
+ above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as
+ exact (up to floating point error) methods.
- and,
+ A common choice for the conditional `p(x | z)` is a multivariate Normal.
- ```none
- grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
- prob[k] = weight[k] / sqrt(pi)
- lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
+ The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a
+ `K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making
+ it a density on the `K`-simplex. That is,
+
+ ```
+ Z = SoftmaxCentered(X),
+ X = Normal(mix_loc / temperature, 1 / temperature)
```
- The distribution corresponding to `phi` must be a scalar-batch, scalar-event
- distribution. Typically it is reparameterized. If not, it must be a function
- of non-trainable parameters.
+ The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of
+ the quantiles of `p(z)` (generalized quantiles if `K > 2`).
- WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
- distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
- variables, then the gradient is not guaranteed correct!
+ See [1] for more details.
+
+ [1]. "Quadrature Compound: An approximating family of distributions"
+ Joshua Dillon, Ian Langmore, arXiv preprints
+ https://arxiv.org/abs/1801.03080
#### About `Vector` distributions in TensorFlow.
@@ -310,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution):
particularly useful in [variational Bayesian
methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods).
- Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose
+ Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose
components are linear combinations of affine transformations, thus is itself
- an affine transformation. Therefore `Y|v` lives in the vector space generated
- by vectors of affine-transformed distributions.
+ an affine transformation.
- Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some
+ Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some
parameterization of `distribution`. This is due to the fact that the sum of
draws from `distribution` are not generally itself the same `distribution`.
@@ -331,12 +280,16 @@ class VectorDiffeomixture(distribution_lib.Distribution):
optimize Monte-Carlo objectives. Such objectives are a finite-sample
approximation of an expectation and arise throughout scientific computing.
+ WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
+ distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
+ variables, then the gradient is not guaranteed correct!
+
#### Examples
```python
tfd = tf.contrib.distributions
- # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and
+ # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
# transformations involve:
# k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity
@@ -344,7 +297,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
dims = 5
vdm = tfd.VectorDiffeomixture(
mix_loc=[[0.], [1]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=tfd.Normal(loc=0., scale=1.),
loc=[
None, # Equivalent to `np.zeros(dims, dtype=np.float32)`.
@@ -364,7 +317,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
def __init__(self,
mix_loc,
- mix_scale,
+ temperature,
distribution,
loc=None,
scale=None,
@@ -373,15 +326,24 @@ class VectorDiffeomixture(distribution_lib.Distribution):
validate_args=False,
allow_nan_stats=True,
name="VectorDiffeomixture"):
- """Constructs the VectorDiffeomixture on `R**d`.
+ """Constructs the VectorDiffeomixture on `R^d`.
+
+ The vector diffeomixture (VDM) approximates the compound distribution
+
+ ```none
+ p(x) = int p(x | z) p(z) dz,
+ where z is in the K-simplex, and
+ p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
+ ```
Args:
- mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents
- the `location` parameter of the SoftmaxNormal used for selecting one of
- the `K` affine transformations.
- mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
- Represents the `scale` parameter of the SoftmaxNormal used for selecting
- one of the `K` affine transformations.
+ mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
+ In terms of samples, larger `mix_loc[..., k]` ==>
+ `Z` is more likely to put more weight on its `kth` component.
+ temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
+ In terms of samples, smaller `temperature` means one component is more
+ likely to dominate. I.e., smaller `temperature` makes the VDM look more
+ like a standard mixture of `K` components.
distribution: `tf.Distribution`-like instance. Distribution from which `d`
iid samples are used as input to the selected affine transformation.
Must be a scalar-batch, scalar-event distribution. Typically
@@ -401,8 +363,9 @@ class VectorDiffeomixture(distribution_lib.Distribution):
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
quadrature_size: Python `int` scalar representing number of
- quadrature points.
- quadrature_fn: Python callable taking `mix_loc`, `mix_scale`,
+ quadrature points. Larger `quadrature_size` means `q_N(x)` better
+ approximates `p(x)`.
+ quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
`quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
representing the SoftmaxNormal grid and corresponding normalized weight.
normalized) weight.
@@ -430,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_event`.
"""
parameters = locals()
- with ops.name_scope(name, values=[mix_loc, mix_scale]):
+ with ops.name_scope(name, values=[mix_loc, temperature]):
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
"LinearOperators, one for each component with "
@@ -473,8 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution):
raise NotImplementedError("Currently only bimixtures are supported; "
"len(scale)={} is not 2.".format(len(scale)))
+ mix_loc = ops.convert_to_tensor(
+ mix_loc, dtype=dtype, name="mix_loc")
+ temperature = ops.convert_to_tensor(
+ temperature, dtype=dtype, name="temperature")
self._grid, probs = tuple(quadrature_fn(
- mix_loc, mix_scale, quadrature_size, validate_args))
+ mix_loc / temperature,
+ 1. / temperature,
+ quadrature_size,
+ validate_args))
# Note: by creating the logits as `log(prob)` we ensure that
# `self.mixture_distribution.logits` is equivalent to
@@ -618,7 +588,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
weight = array_ops.gather(
array_ops.reshape(self.grid, shape=[-1]),
ids + offset)
- weight = weight[..., array_ops.newaxis]
+ # At this point, weight flattened all batch dims into one.
+ # We also need to append a singleton to broadcast with event dims.
+ if self.batch_shape.is_fully_defined():
+ new_shape = [-1] + self.batch_shape.as_list() + [1]
+ else:
+ new_shape = array_ops.concat(
+ ([-1], self.batch_shape_tensor(), [1]), axis=0)
+ weight = array_ops.reshape(weight, shape=new_shape)
if len(x) != 2:
# We actually should have already triggered this exception. However as a
@@ -686,7 +663,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
# To compute E[Cov(Z|V)], we'll add matrices within three categories:
# scaled-identity, diagonal, and full. Then we'll combine these at the end.
- scaled_identity = None
+ scale_identity_multiplier = None
diag = None
full = None
@@ -694,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
s = aff.scale # Just in case aff.scale has side-effects, we'll call once.
if (s is None
or isinstance(s, linop_identity_lib.LinearOperatorIdentity)):
- scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis])
+ scale_identity_multiplier = add(scale_identity_multiplier,
+ p[..., k, array_ops.newaxis])
elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity):
- scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] *
- math_ops.square(s.multiplier)))
+ scale_identity_multiplier = add(
+ scale_identity_multiplier,
+ (p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier)))
elif isinstance(s, linop_diag_lib.LinearOperatorDiag):
diag = add(diag, (p[..., k, array_ops.newaxis] *
math_ops.square(s.diag_part())))
@@ -709,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution):
full = add(full, x)
# We must now account for the fact that the base distribution might have a
- # non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`.
+ # non-unity variance. Recall that, since X ~ iid Law(X_0),
+ # `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`.
# We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid
# samples from a scalar-event distribution.
v = self.distribution.variance()
- if scaled_identity is not None:
- scaled_identity *= v
+ if scale_identity_multiplier is not None:
+ scale_identity_multiplier *= v
if diag is not None:
diag *= v[..., array_ops.newaxis]
if full is not None:
@@ -723,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
if diag_only:
# Apparently we don't need the full matrix, just the diagonal.
r = add(diag, full)
- if r is None and scaled_identity is not None:
+ if r is None and scale_identity_multiplier is not None:
ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype)
- return scaled_identity * ones
- return add(r, scaled_identity)
+ return scale_identity_multiplier[..., array_ops.newaxis] * ones
+ return add(r, scale_identity_multiplier)
# `None` indicates we don't know if the result is positive-definite.
is_positive_definite = (True if all(aff.scale.is_positive_definite
@@ -742,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
to_add.append(linop_full_lib.LinearOperatorFullMatrix(
matrix=full,
is_positive_definite=is_positive_definite))
- if scaled_identity is not None:
+ if scale_identity_multiplier is not None:
to_add.append(linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=self.event_shape_tensor()[0],
- multiplier=scaled_identity,
+ multiplier=scale_identity_multiplier,
is_positive_definite=is_positive_definite))
return (linop_add_lib.add_operators(to_add)[0].to_dense()