aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 14:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:33:20 -0700
commit325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch)
treed41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/contrib/distributions/python
parent46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff)
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/contrib/distributions/python')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py109
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/invert.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/ordered.py125
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/permute.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/reshape.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/weibull.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/shape.py2
11 files changed, 249 insertions, 13 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
new file mode 100644
index 0000000000..a5f5219588
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.ordered import Ordered
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.platform import test
+
+
+
+class OrderedBijectorTest(test.TestCase):
+ """Tests correctness of the ordered transformation."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBijectorVector(self):
+ with self.test_session():
+ ordered = Ordered()
+ self.assertEqual("ordered", ordered.name)
+ x = np.asarray([[2., 3, 4], [4., 8, 13]])
+ y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
+ self.assertAllClose(y, self.evaluate(ordered.forward(x)))
+ self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
+ self.assertAllClose(
+ np.sum(np.asarray(y)[..., 1:], axis=-1),
+ self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
+ self.assertAllClose(
+ self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
+
+ def testBijectorUnknownShape(self):
+ with self.test_session():
+ ordered = Ordered()
+ self.assertEqual("ordered", ordered.name)
+ x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+ real_x = np.asarray([[2., 3, 4], [4., 8, 13]])
+ y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
+ real_y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
+ self.assertAllClose(real_y, ordered.forward(x).eval(
+ feed_dict={x: real_x}))
+ self.assertAllClose(real_x, ordered.inverse(y).eval(
+ feed_dict={y: real_y}))
+ self.assertAllClose(
+ np.sum(np.asarray(real_y)[..., 1:], axis=-1),
+ ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(
+ feed_dict={y: real_y}),
+ atol=0.,
+ rtol=1e-7)
+ self.assertAllClose(
+ -ordered.inverse_log_det_jacobian(y, event_ndims=1).eval(
+ feed_dict={y: real_y}),
+ ordered.forward_log_det_jacobian(x, event_ndims=1).eval(
+ feed_dict={x: real_x}),
+ atol=0.,
+ rtol=1e-7)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testShapeGetters(self):
+ with self.test_session():
+ x = tensor_shape.TensorShape([4])
+ y = tensor_shape.TensorShape([4])
+ bijector = Ordered(validate_args=True)
+ self.assertAllEqual(y, bijector.forward_event_shape(x))
+ self.assertAllEqual(y.as_list(),
+ self.evaluate(bijector.forward_event_shape_tensor(
+ x.as_list())))
+ self.assertAllEqual(x, bijector.inverse_event_shape(y))
+ self.assertAllEqual(x.as_list(),
+ self.evaluate(bijector.inverse_event_shape_tensor(
+ y.as_list())))
+
+ def testBijectiveAndFinite(self):
+ with self.test_session():
+ ordered = Ordered()
+ x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32)
+ y = (self._rng.randn(3, 10)).astype(np.float32)
+ assert_bijective_and_finite(ordered, x, y, event_ndims=1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index babce80396..51478dbeff 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -30,6 +30,7 @@
@@Invert
@@Kumaraswamy
@@MaskedAutoregressiveFlow
+@@Ordered
@@Permute
@@PowerTransform
@@RealNVP
@@ -67,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
+from tensorflow.contrib.distributions.python.ops.bijectors.ordered import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
index caae2adcfa..ecdb8967f4 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
@@ -170,7 +170,7 @@ class CholeskyOuterProduct(bijector.Bijector):
sum_weighted_log_diag = array_ops.squeeze(
math_ops.matmul(math_ops.log(diag),
exponents[..., array_ops.newaxis]),
- squeeze_dims=-1)
+ axis=-1)
fldj = p_float * np.log(2.) + sum_weighted_log_diag
return fldj
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
index 1904239a0e..84a3289ba2 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.ops.distributions import bijector as bijector_lib
+from tensorflow.python.ops.distributions import bijector
__all__ = [
"Invert",
]
-class Invert(bijector_lib.Bijector):
+class Invert(bijector.Bijector):
"""Bijector which inverts another Bijector.
Example Use: [ExpGammaDistribution (see Background & Context)](
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index ef56cf6ddd..83667b0e80 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
from tensorflow.python.ops import variable_scope as variable_scope_lib
-from tensorflow.python.ops.distributions import bijector as bijector_lib
+from tensorflow.python.ops.distributions import bijector
__all__ = [
@@ -42,7 +42,7 @@ __all__ = [
]
-class MaskedAutoregressiveFlow(bijector_lib.Bijector):
+class MaskedAutoregressiveFlow(bijector.Bijector):
"""Affine MaskedAutoregressiveFlow bijector for vector-valued events.
The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
new file mode 100644
index 0000000000..3f03592f31
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ordered bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+
+__all__ = [
+ "Ordered",
+]
+
+
+class Ordered(bijector.Bijector):
+ """Bijector which maps a tensor x_k that has increasing elements in the last
+ dimension to an unconstrained tensor y_k.
+
+ Both the domain and the codomain of the mapping is [-inf, inf], however,
+ the input of the forward mapping must be strictly increasing.
+ The inverse of the bijector applied to a normal random vector `y ~ N(0, 1)`
+ gives back a sorted random vector with the same distribution `x ~ N(0, 1)`
+ where `x = sort(y)`
+
+ On the last dimension of the tensor, Ordered bijector performs:
+ `y[0] = x[0]`
+ `y[1:] = math_ops.log(x[1:] - x[:-1])`
+
+ #### Example Use:
+
+ ```python
+ bijector.Ordered().forward([2, 3, 4])
+ # Result: [2., 0., 0.]
+
+ bijector.Ordered().inverse([0.06428002, -1.07774478, -0.71530371])
+ # Result: [0.06428002, 0.40464228, 0.8936858]
+ ```
+ """
+
+ def __init__(self, validate_args=False, name="ordered"):
+ super(Ordered, self).__init__(
+ forward_min_event_ndims=1,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward_event_shape(self, input_shape):
+ if input_shape.ndims is None or input_shape[-1] is None:
+ return input_shape
+ return tensor_shape.TensorShape([input_shape[-1]])
+
+ def _forward_event_shape_tensor(self, input_shape):
+ return (input_shape[-1])[..., array_ops.newaxis]
+
+ def _inverse_event_shape(self, output_shape):
+ if output_shape.ndims is None or output_shape[-1] is None:
+ return output_shape
+ if output_shape[-1] <= 1:
+ raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1])
+ return tensor_shape.TensorShape([output_shape[-1]])
+
+ def _inverse_event_shape_tensor(self, output_shape):
+ if self.validate_args:
+ is_greater_one = check_ops.assert_greater(
+ output_shape[-1], 1, message="Need last dimension greater than 1.")
+ output_shape = control_flow_ops.with_dependencies(
+ [is_greater_one], output_shape)
+ return (output_shape[-1])[..., array_ops.newaxis]
+
+ def _forward(self, x):
+ x = self._maybe_assert_valid_x(x)
+ y0 = x[..., 0, array_ops.newaxis]
+ yk = math_ops.log(x[..., 1:] - x[..., :-1])
+ y = array_ops.concat([y0, yk], axis=-1)
+ return y
+
+ def _inverse(self, y):
+ x0 = y[..., 0, array_ops.newaxis]
+ xk = math_ops.exp(y[..., 1:])
+ x = array_ops.concat([x0, xk], axis=-1)
+ return math_ops.cumsum(x, axis=-1)
+
+ def _inverse_log_det_jacobian(self, y):
+ # The Jacobian of the inverse mapping is lower
+ # triangular, with the diagonal elements being:
+ # J[i,i] = 1 if i=1, and
+ # exp(y_i) if 1<i<=K
+ # which gives the absolute Jacobian determinant:
+ # |det(Jac)| = prod_{i=1}^{K} exp(y[i]).
+ # (1) - Stan Modeling Language User's Guide and Reference Manual
+ # Version 2.17.0 session 35.2
+ return math_ops.reduce_sum(y[..., 1:], axis=-1)
+
+ def _forward_log_det_jacobian(self, x):
+ x = self._maybe_assert_valid_x(x)
+ return -math_ops.reduce_sum(
+ math_ops.log(x[..., 1:] - x[..., :-1]),
+ axis=-1)
+
+ def _maybe_assert_valid_x(self, x):
+ if not self.validate_args:
+ return x
+ is_valid = check_ops.assert_positive(
+ x[..., 1:] - x[..., :-1],
+ message="Forward transformation input must be strictly increasing.")
+ return control_flow_ops.with_dependencies([is_valid], x)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index 4978167803..12a16a3f2b 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops.distributions import bijector as bijector_lib
+from tensorflow.python.ops.distributions import bijector
__all__ = [
@@ -36,7 +36,7 @@ __all__ = [
]
-class Permute(bijector_lib.Bijector):
+class Permute(bijector.Bijector):
"""Permutes the rightmost dimension of a `Tensor`.
```python
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index f09ab21bce..66e8a5b9b3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -25,7 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
-from tensorflow.python.ops.distributions import bijector as bijector_lib
+from tensorflow.python.ops.distributions import bijector
__all__ = [
@@ -34,7 +34,7 @@ __all__ = [
]
-class RealNVP(bijector_lib.Bijector):
+class RealNVP(bijector.Bijector):
"""RealNVP "affine coupling layer" for vector-valued events.
Real NVP models a normalizing flow on a `D`-dimensional distribution via a
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index f21b982ba6..5497c422e4 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import bijector as bijector_lib
+from tensorflow.python.ops.distributions import bijector
__all__ = [
@@ -44,7 +44,7 @@ def _ndims_from_shape(shape):
return array_ops.shape(shape)[0]
-class Reshape(bijector_lib.Bijector):
+class Reshape(bijector.Bijector):
"""Reshapes the `event_shape` of a `Tensor`.
The semantics generally follow that of `tf.reshape()`, with
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
index 39129cd22c..a22560fe80 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
@@ -128,7 +128,7 @@ class Weibull(bijector.Bijector):
return x
is_valid = check_ops.assert_non_negative(
x,
- message="Forward transformation input must be at least {}.".format(0))
+ message="Forward transformation input must be at least 0.")
return control_flow_ops.with_dependencies([is_valid], x)
def _maybe_assert_valid_y(self, y):
diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py
index bac0b79d59..6a7f28713a 100644
--- a/tensorflow/contrib/distributions/python/ops/shape.py
+++ b/tensorflow/contrib/distributions/python/ops/shape.py
@@ -439,7 +439,7 @@ class _DistributionShape(object):
if self._batch_ndims_is_0 and expand_batch_dim:
squeeze_dims += [1]
if squeeze_dims:
- x = array_ops.squeeze(x, squeeze_dims=squeeze_dims)
+ x = array_ops.squeeze(x, axis=squeeze_dims)
# x.shape: [prod(S)]+B+E
_, batch_shape, event_shape = self.get_shape(x)
else: