aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distributions/python/ops/transformed_distribution.py')
-rw-r--r--tensorflow/contrib/distributions/python/ops/transformed_distribution.py76
1 files changed, 51 insertions, 25 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
index 1403adbda2..844f78ca96 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -19,11 +19,9 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.distributions.python.ops import distribution as distribution_lib
+from tensorflow.contrib.distributions.python.ops import distribution as distributions
from tensorflow.contrib.distributions.python.ops import distribution_util
-# Bijectors must be directly imported because `remove_undocumented` prevents
-# individual file imports.
-from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity
+from tensorflow.contrib.distributions.python.ops.bijectors import identity as identity_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -121,7 +119,7 @@ def _is_scalar_from_shape(shape):
return _logical_equal(_ndims_from_shape(shape), 0)
-class TransformedDistribution(distribution_lib.Distribution):
+class TransformedDistribution(distributions.Distribution):
"""A Transformed Distribution.
A `TransformedDistribution` models `p(y)` given a base distribution `p(x)`,
@@ -148,19 +146,49 @@ class TransformedDistribution(distribution_lib.Distribution):
A `TransformedDistribution` implements the following operations:
- * `sample`
- Mathematically: `Y = g(X)`
- Programmatically: `bijector.forward(distribution.sample(...))`
+ * `sample`:
- * `log_prob`
- Mathematically: `(log o pdf)(Y=y) = (log o pdf o g^{-1})(y)
- + (log o abs o det o J o g^{-1})(y)`
- Programmatically: `(distribution.log_prob(bijector.inverse(y))
- + bijector.inverse_log_det_jacobian(y))`
+ Mathematically:
- * `log_cdf`
- Mathematically: `(log o cdf)(Y=y) = (log o cdf o g^{-1})(y)`
- Programmatically: `distribution.log_cdf(bijector.inverse(x))`
+ ```none
+ Y = g(X)
+ ```
+
+ Programmatically:
+
+ ```python
+ return bijector.forward(distribution.sample(...))
+ ```
+
+ * `log_prob`:
+
+ Mathematically:
+
+ ```none
+ (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) +
+ (log o abs o det o J o g^{-1})(y)
+ ```
+
+ Programmatically:
+
+ ```python
+ return (distribution.log_prob(bijector.inverse(y)) +
+ bijector.inverse_log_det_jacobian(y))
+ ```
+
+ * `log_cdf`:
+
+ Mathematically:
+
+ ```none
+ (log o cdf)(Y=y) = (log o cdf o g^{-1})(y)
+ ```
+
+ Programmatically:
+
+ ```python
+ return distribution.log_cdf(bijector.inverse(x))
+ ```
* and similarly for: `cdf`, `prob`, `log_survival_function`,
`survival_function`.
@@ -171,7 +199,7 @@ class TransformedDistribution(distribution_lib.Distribution):
```python
ds = tf.contrib.distributions
log_normal = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
+ distribution=ds.Normal(loc=mu, scale=sigma),
bijector=ds.bijectors.Exp(),
name="LogNormalTransformedDistribution")
```
@@ -181,7 +209,7 @@ class TransformedDistribution(distribution_lib.Distribution):
```python
ds = tf.contrib.distributions
log_normal = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
+ distribution=ds.Normal(loc=mu, scale=sigma),
bijector=ds.bijectors.Inline(
forward_fn=tf.exp,
inverse_fn=tf.log,
@@ -195,11 +223,8 @@ class TransformedDistribution(distribution_lib.Distribution):
```python
ds = tf.contrib.distributions
normal = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
- bijector=ds.bijectors.Affine(
- shift=-1.,
- scale_identity_multiplier=2.,
- event_ndims=0),
+ distribution=ds.Normal(loc=0, scale=1),
+ bijector=ds.bijectors.ScaleAndShift(loc=mu, scale=sigma, event_ndims=0),
name="NormalTransformedDistribution")
```
@@ -212,6 +237,7 @@ class TransformedDistribution(distribution_lib.Distribution):
multivariate Normal as a `TransformedDistribution`.
```python
+ bs = tf.contrib.distributions.bijector
ds = tf.contrib.distributions
# We will create two MVNs with batch_shape = event_shape = 2.
mean = [[-1., 0], # batch:0
@@ -222,7 +248,7 @@ class TransformedDistribution(distribution_lib.Distribution):
[2, 2]]] # batch:1
mvn1 = ds.TransformedDistribution(
distribution=ds.Normal(loc=0., scale=1.),
- bijector=ds.bijectors.Affine(shift=mean, scale_tril=chol_cov),
+ bijector=bs.Affine(shift=mean, tril=chol_cov),
batch_shape=[2], # Valid because base_distribution.batch_shape == [].
event_shape=[2]) # Valid because base_distribution.event_shape == [].
mvn2 = ds.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov)
@@ -265,7 +291,7 @@ class TransformedDistribution(distribution_lib.Distribution):
self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")
if bijector is None:
- bijector = Identity(validate_args=validate_args)
+ bijector = identity_lib.Identity(validate_args=validate_args)
# We will keep track of a static and dynamic version of
# self._is_{batch,event}_override. This way we can do more prior to graph