aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-16 11:17:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-16 11:21:33 -0700
commit2487732ff111daedaf489672700ccfbf2088c3de (patch)
tree4d37982fa74e33f3361830a2bf3513899f75c13e
parentf0e3edf8b1c8de49672d78abe73dcd0b1f02620c (diff)
Add tf.contrib.distributions.bijectors.Gumbel.
PiperOrigin-RevId: 172350038
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py70
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py29
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py124
5 files changed, 244 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 93770c37de..825ec652d0 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -798,6 +798,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "gumbel_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/gumbel_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "inline_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/inline_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
new file mode 100644
index 0000000000..9a905980c7
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
@@ -0,0 +1,70 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+
+from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import Gumbel
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class GumbelBijectorTest(test.TestCase):
+ """Tests correctness of the Gumbel bijector."""
+
+ def testBijector(self):
+ with self.test_session():
+ loc = 0.3
+ scale = 5.
+ bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True)
+ self.assertEqual("gumbel", bijector.name)
+ x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32)
+ # Gumbel distribution
+ gumbel_dist = stats.gumbel_r(loc=loc, scale=scale)
+ y = gumbel_dist.cdf(x).astype(np.float32)
+ self.assertAllClose(y, bijector.forward(x).eval())
+ self.assertAllClose(x, bijector.inverse(y).eval())
+ self.assertAllClose(
+ # We should lose a dimension from calculating the determinant of the
+ # jacobian.
+ np.squeeze(gumbel_dist.logpdf(x), axis=2),
+ bijector.forward_log_det_jacobian(x).eval())
+ self.assertAllClose(
+ -bijector.inverse_log_det_jacobian(y).eval(),
+ bijector.forward_log_det_jacobian(x).eval(),
+ rtol=1e-4,
+ atol=0.)
+
+ def testScalarCongruency(self):
+ with self.test_session():
+ assert_scalar_congruency(
+ Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02)
+
+ def testBijectiveAndFinite(self):
+ with self.test_session():
+ bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True)
+ x = np.linspace(-10., 10., num=10).astype(np.float32)
+ y = np.linspace(0.01, 0.99, num=10).astype(np.float32)
+ assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index c9ed546a34..e62f900bbf 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -22,6 +22,7 @@
@@CholeskyOuterProduct
@@ConditionalBijector
@@Exp
+@@Gumbel
@@Identity
@@Inline
@@Invert
@@ -48,6 +49,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
+from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
new file mode 100644
index 0000000000..cf37aa5111
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
@@ -0,0 +1,29 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gumbel bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ["Gumbel"]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py
new file mode 100644
index 0000000000..67f3978556
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gumbel bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+__all__ = [
+ "Gumbel",
+]
+
+
+class Gumbel(bijector.Bijector):
+ """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`.
+
+ This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the
+ bijector applied to a uniform random variable `X ~ U(0, 1) gives back a
+ random variable with the
+ [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution):
+
+ ```none
+ Y ~ Gumbel(loc, scale)
+ pdf(y; loc, scale) = exp(
+ -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale
+ ```
+ """
+
+ def __init__(self,
+ loc=0.,
+ scale=1.,
+ event_ndims=0,
+ validate_args=False,
+ name="gumbel"):
+ """Instantiates the `Gumbel` bijector.
+
+ Args:
+ loc: Float-like `Tensor` that is the same dtype and is
+ broadcastable with `scale`.
+ This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.
+ scale: Positive Float-like `Tensor` that is the same dtype and is
+ broadcastable with `loc`.
+ This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.
+ event_ndims: Python scalar indicating the number of dimensions associated
+ with a particular draw from the distribution.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._graph_parents = []
+ self._name = name
+ self._validate_args = validate_args
+ with self._name_scope("init", values=[loc, scale]):
+ self._loc = ops.convert_to_tensor(loc, name="loc")
+ self._scale = ops.convert_to_tensor(scale, name="scale")
+ check_ops.assert_same_float_dtype([self._loc, self._scale])
+ if validate_args:
+ self._scale = control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ self._scale, message="Argument scale was not positive")
+ ], self._scale)
+
+ super(Gumbel, self).__init__(
+ event_ndims=event_ndims, validate_args=validate_args, name=name)
+
+ @property
+ def loc(self):
+ """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`."""
+ return self._loc
+
+ @property
+ def scale(self):
+ """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`."""
+ return self._scale
+
+ def _forward(self, x):
+ z = (x - self.loc) / self.scale
+ return math_ops.exp(-math_ops.exp(-z))
+
+ def _inverse(self, y):
+ y = self._maybe_assert_valid_y(y)
+ return self.loc - self.scale * math_ops.log(-math_ops.log(y))
+
+ def _inverse_log_det_jacobian(self, y):
+ y = self._maybe_assert_valid_y(y)
+ event_dims = self._event_dims_tensor(y)
+ return math_ops.reduce_sum(
+ math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims)
+
+ def _forward_log_det_jacobian(self, x):
+ event_dims = self._event_dims_tensor(x)
+ z = (x - self.loc) / self.scale
+ return math_ops.reduce_sum(
+ -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims)
+
+ def _maybe_assert_valid_y(self, y):
+ if not self.validate_args:
+ return y
+ is_positive = check_ops.assert_non_negative(
+ y, message="Inverse transformation input must be greater than 0.")
+ less_than_one = check_ops.assert_less_equal(
+ y,
+ constant_op.constant(1., y.dtype),
+ message="Inverse transformation input must be less than or equal to 1.")
+ return control_flow_ops.with_dependencies([is_positive, less_than_one], y)