aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py64
1 files changed, 57 insertions, 7 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 1ef7651d03..eb94760ad7 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -128,7 +128,7 @@ The base distribution's `log_cdf` method must be defined on `y - 1`.
class QuantizedDistribution(distributions.Distribution):
"""Distribution representing the quantization `Y = ceiling(X)`.
- #### Definition in terms of sampling.
+ #### Definition in Terms of Sampling
```
1. Draw X
@@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution):
5. Return Y
```
- #### Definition in terms of the probability mass function.
+ #### Definition in Terms of the Probability Mass Function
Given scalar random variable `X`, we define a discrete random variable `Y`
supported on the integers as follows:
@@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution):
`P[Y = j]` is still the mass of `X` within the `jth` interval.
- #### Caveats
+ #### Examples
+
+ We illustrate a mixture of discretized logistic distributions
+ [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
+ audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
+ a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
+ `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
+ The lowest value has probability `P(X <= 0.5)` and the highest value has
+ probability `P(2**16 - 1.5 < X)`.
+
+ Below we assume a `wavenet` function. It takes as `input` right-shifted audio
+ samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
+ shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
+ `scale` parameter belonging to the logistic distribution, and a `logits`
+ parameter determining the unnormalized probability of that component.
+
+ ```python
+ tfd = tf.contrib.distributions
+ tfb = tfd.bijectors
+
+ net = wavenet(inputs)
+ loc, unconstrained_scale, logits = tf.split(net,
+ num_or_size_splits=3,
+ axis=-1)
+ scale = tf.nn.softplus(unconstrained_scale)
+
+ # Form mixture of discretized logistic distributions. Note we shift the
+ # logistic distribution by -0.5. This lets the quantization capture "rounding"
+ # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
+ discretized_logistic_dist = tfd.QuantizedDistribution(
+ distribution=tfd.TransformedDistribution(
+ distribution=tfd.Logistic(loc=loc, scale=scale),
+ bijector=tfb.AffineScalar(shift=-0.5)),
+ low=0.,
+ high=2**16 - 1.)
+ mixture_dist = tfd.MixtureSameFamily(
+ mixture_distribution=tfd.Categorical(logits=logits),
+ components_distribution=discretized_logistic_dist)
+
+ neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
+ train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
+ ```
+
+  After instantiating `mixture_dist`, we illustrate maximum likelihood by
+  calculating its log-probability of audio samples as `targets` and optimizing.
+
+ #### References
- Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
- a closed form function such as for a Poisson), computations such as mean and
- entropy are better done with samples or approximations, and are not
- implemented by this class.
+ [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
+ PixelCNN++: Improving the PixelCNN with discretized logistic mixture
+ likelihood and other modifications.
+ _International Conference on Learning Representations_, 2017.
+ https://arxiv.org/abs/1701.05517
+ [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+ Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+ https://arxiv.org/abs/1711.10433
"""
def __init__(self,