aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Joshua V. Dillon <jvdillon@google.com> 2017-12-01 13:36:03 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-12-01 13:40:02 -0800
commit: 71066d78d163bad92baf539cde54d167d758305e (patch)
tree: 4cf998246dac1cc17dd7d51e29252ec85e054842
parent: e88855650435327899917afb6723db03a3d5469f (diff)
Add `tf.contrib.distributions.Autoregressive`.
PiperOrigin-RevId: 177633858
-rw-r--r-- tensorflow/contrib/distributions/BUILD | 13
-rw-r--r-- tensorflow/contrib/distributions/__init__.py | 2
-rw-r--r-- tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py | 94
-rw-r--r-- tensorflow/contrib/distributions/python/ops/autoregressive.py | 208
-rw-r--r-- tensorflow/python/ops/distributions/util.py | 1
5 files changed, 318 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 145b9495ff..c5bd91484e 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -128,6 +128,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "autoregressive_test",
+ size = "small",
+ srcs = ["python/kernel_tests/autoregressive_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "binomial_test",
size = "small",
srcs = ["python/kernel_tests/binomial_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 0d12d83893..a8cf40c52e 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -23,6 +23,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.contrib.distributions.python.ops.autoregressive import *
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.cauchy import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
@@ -91,6 +92,7 @@ _allowed_symbols = [
'NOT_REPARAMETERIZED',
'ReparameterizationType',
'Distribution',
+ 'Autoregressive',
'Binomial',
'Bernoulli',
'BernoulliWithSigmoidProbs',
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
new file mode 100644
index 0000000000..b625093fb7
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
@@ -0,0 +1,94 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import autoregressive as autoregressive_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
+from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.platform import test
+
+
+class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase):
+ """Tests the Autoregressive distribution."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def _random_scale_tril(self, event_size):
+ n = np.int32(event_size * (event_size + 1) // 2)
+ p = 2. * self._rng.random_sample(n).astype(np.float32) - 1.
+ return distribution_util.fill_triangular(0.25 * p)
+
+ def _normal_fn(self, affine_bijector):
+ def _fn(samples):
+ scale = math_ops.exp(affine_bijector.forward(samples))
+ return independent_lib.Independent(
+ normal_lib.Normal(loc=0., scale=scale, validate_args=True),
+ reinterpreted_batch_ndims=1)
+ return _fn
+
+ def testSampleAndLogProbConsistency(self):
+ batch_shape = []
+ event_size = 2
+ with self.test_session() as sess:
+ batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
+ sample0 = array_ops.zeros(batch_event_shape)
+ affine = Affine(scale_tril=self._random_scale_tril(event_size))
+ ar = autoregressive_lib.Autoregressive(
+ self._normal_fn(affine), sample0, validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess.run, ar, radius=1., center=0., rtol=0.01)
+
+ def testCompareToBijector(self):
+ """Demonstrates equivalence between TD, Bijector approach and AR dist."""
+ sample_shape = [4, 5]
+ batch_shape = []
+ event_size = 2
+ with self.test_session() as sess:
+ batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
+ sample0 = array_ops.zeros(batch_event_shape)
+ affine = Affine(scale_tril=self._random_scale_tril(event_size))
+ ar = autoregressive_lib.Autoregressive(
+ self._normal_fn(affine), sample0, validate_args=True)
+ ar_flow = MaskedAutoregressiveFlow(
+ is_constant_jacobian=True,
+ shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
+ validate_args=True)
+ td = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=ar_flow,
+ event_shape=[event_size],
+ batch_shape=batch_shape,
+ validate_args=True)
+ x_shape = np.concatenate(
+ [sample_shape, batch_shape, [event_size]], axis=0)
+ x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
+ td_log_prob_, ar_log_prob_ = sess.run([td.log_prob(x), ar.log_prob(x)])
+ self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
new file mode 100644
index 0000000000..852298bf33
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -0,0 +1,208 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Autoregressive distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
+
+
+class Autoregressive(distribution_lib.Distribution):
+ """Autoregressive distributions.
+
+ The Autoregressive distribution enables learning (often) richer multivariate
+ distributions by repeatedly applying a [diffeomorphic](
+ https://en.wikipedia.org/wiki/Diffeomorphism) transformation (such as
+ implemented by `Bijector`s). Regarding terminology,
+
+ "Autoregressive models decompose the joint density as a product of
+ conditionals, and model each conditional in turn. Normalizing flows
+ transform a base density (e.g. a standard Gaussian) into the target density
+ by an invertible transformation with tractable Jacobian." [1]
+
+ In other words, the "autoregressive property" is equivalent to the
+ decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided
+ `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
+ this property by zeroing out weights in its `masked_dense` layers.
+
+ Practically speaking the autoregressive property means that there exists a
+ permutation of the event coordinates such that each coordinate is a
+ diffeomorphic function of only preceding coordinates. [2]
+
+ #### Mathematical Details
+
+ The probability function is,
+
+ ```none
+ prob(x; fn, n) = fn(x).prob(x)
+ ```
+
+ And a sample is generated by,
+
+ ```none
+ x = fn(...fn(fn(x0).sample()).sample()).sample()
+ ```
+
+ where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
+ constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+ fixed initializing `Tensor`.
+
+ #### Examples
+
+ ```python
+ tfd = tf.contrib.distributions
+
+ def normal_fn(event_size):
+ n = event_size * (event_size + 1) // 2
+ p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n))
+ affine = tfd.bijectors.Affine(
+ scale_tril=tfd.fill_triangular(0.25 * p))
+ def _fn(samples):
+ scale = math_ops.exp(affine.forward(samples))
+ return independent_lib.Independent(
+ normal_lib.Normal(loc=0., scale=scale, validate_args=True),
+ reinterpreted_batch_ndims=1)
+ return _fn
+
+ batch_and_event_shape = [3, 2, 4]
+ sample0 = array_ops.zeros(batch_and_event_shape)
+ ar = autoregressive_lib.Autoregressive(
+ normal_fn(batch_and_event_shape[-1]), sample0)
+ x = ar.sample([6, 5])
+ # ==> x.shape = [6, 5, 3, 2, 4]
+ prob_x = ar.prob(x)
+ # ==> prob_x.shape = [6, 5, 3, 2]
+
+ ```
+
+ [1]: "Masked Autoregressive Flow for Density Estimation."
+ George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
+ https://arxiv.org/abs/1705.07057
+
+ [2]: "Conditional Image Generation with PixelCNN Decoders."
+ Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex
+ Graves, Koray Kavukcuoglu. Arxiv, 2016.
+ https://arxiv.org/abs/1606.05328
+ """
+
+ def __init__(self,
+ distribution_fn,
+ sample0=None,
+ num_steps=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Autoregressive"):
+ """Construct an `Autoregressive` distribution.
+
+ Args:
+ distribution_fn: Python `callable` which constructs a
+ `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+ `sample0`). The function must respect the "autoregressive property",
+ i.e., there exists a permutation of event such that each coordinate is a
+ diffeomorphic function of only preceding coordinates.
+ sample0: Initial input to `distribution_fn`; used to
+ build the distribution in `__init__` which in turn specifies this
+ distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`.
+ If unspecified, then `distribution_fn` should be default constructable.
+ num_steps: Number of times `distribution_fn` is composed from samples,
+ e.g., `num_steps=2` implies
+ `distribution_fn(distribution_fn(sample0).sample(n)).sample()`.
+ validate_args: Python `bool`. Whether to validate input with asserts.
+ If `validate_args` is `False`, and the inputs are invalid,
+ correct behavior is not guaranteed.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+ Default value: "Autoregressive".
+
+ Raises:
+ ValueError: if `num_steps` and
+ `distribution_fn(sample0).event_shape.num_elements()` are both `None`.
+ ValueError: if `num_steps < 1`.
+ """
+ parameters = locals()
+ with ops.name_scope(name):
+ self._distribution_fn = distribution_fn
+ self._sample0 = sample0
+ self._distribution0 = (distribution_fn() if sample0 is None
+ else distribution_fn(sample0))
+ if num_steps is None:
+ num_steps = self._distribution0.event_shape.num_elements()
+ if num_steps is None:
+ raise ValueError("distribution_fn must generate a distribution "
+ "with fully known `event_shape`.")
+ if num_steps < 1:
+ raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))
+ self._num_steps = num_steps
+ super(Autoregressive, self).__init__(
+ dtype=self._distribution0.dtype,
+ reparameterization_type=self._distribution0.reparameterization_type,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=self._distribution0._graph_parents, # pylint: disable=protected-access
+ name=name)
+
+ @property
+ def distribution_fn(self):
+ return self._distribution_fn
+
+ @property
+ def sample0(self):
+ return self._sample0
+
+ @property
+ def num_steps(self):
+ return self._num_steps
+
+ @property
+ def distribution0(self):
+ return self._distribution0
+
+ def _batch_shape(self):
+ return self.distribution0.batch_shape
+
+ def _batch_shape_tensor(self):
+ return self.distribution0.batch_shape_tensor()
+
+ def _event_shape(self):
+ return self.distribution0.event_shape
+
+ def _event_shape_tensor(self):
+ return self.distribution0.event_shape_tensor()
+
+ def _sample_n(self, n, seed=None):
+ if seed is None:
+ seed = distribution_util.gen_new_seed(
+ seed=np.random.randint(2**32 - 1),
+ salt="autoregressive")
+ samples = self.distribution0.sample(n, seed=seed)
+ for _ in range(self._num_steps):
+ samples = self.distribution_fn(samples).sample(seed=seed)
+ return samples
+
+ def _log_prob(self, value):
+ return self.distribution_fn(value).log_prob(value)
+
+ def _prob(self, value):
+ return self.distribution_fn(value).prob(value)
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 41b86f7940..28c74bf981 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -751,6 +751,7 @@ def fill_triangular(x, upper=False, name=None):
"""
with ops.name_scope(name, "fill_triangular", values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
if x.shape.with_rank_at_least(1)[-1].value is not None:
# Formula derived by solving for n: m = n(n+1)/2.
m = np.int32(x.shape[-1].value)