aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 15:28:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 15:31:28 -0700
commite73c66f8152690b9f2466bfcca887283ed380980 (patch)
tree071cbbca8e6cfa5e796fd7e6f61202bd93be6ae7 /tensorflow/contrib/distributions
parent9f640dc874dba2e10b634cb7e87837f040fa83dc (diff)
Add ScaleTriL Bijector to enable transformed distributions over PSD matrices.
PiperOrigin-RevId: 199706732
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py69
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py114
4 files changed, 204 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 61d4e90ea2..51f7028566 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -1138,6 +1138,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "scale_tril_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "sigmoid_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/sigmoid_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
new file mode 100644
index 0000000000..566a7b3dff
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class ScaleTriLBijectorTest(test.TestCase):
+ """Tests the correctness of the ScaleTriL bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testComputesCorrectValues(self):
+ shift = 1.61803398875
+ x = np.float32(np.array([-1, .5, 2]))
+ y = np.float32(np.array([[np.exp(2) + shift, 0.],
+ [.5, np.exp(-1) + shift]]))
+
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(),
+ diag_shift=shift)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvertible(self):
+
+ # Generate random inputs from an unconstrained space, with
+ # event size 6 to specify 3x3 triangular matrices.
+ batch_shape = [2, 1]
+ x = np.float32(np.random.randn(*(batch_shape + [6])))
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(),
+ diag_shift=3.14159)
+ y = self.evaluate(b.forward(x))
+ self.assertAllEqual(y.shape, batch_shape + [3, 3])
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(fldj, -ildj)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index d97a1f0d30..e141f8b5c6 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -37,6 +37,7 @@
@@PowerTransform
@@RealNVP
@@Reshape
+@@ScaleTriL
@@Sigmoid
@@SinhArcsinh
@@SoftmaxCentered
@@ -78,6 +79,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
+from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
new file mode 100644
index 0000000000..96bd242c63
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar
+from tensorflow.contrib.distributions.python.ops.bijectors import chain
+from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular
+from tensorflow.contrib.distributions.python.ops.bijectors import softplus
+from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal
+
+__all__ = [
+ "ScaleTriL",
+]
+
+
+class ScaleTriL(chain.Chain):
+ """Transforms unconstrained vectors to TriL matrices with positive diagonal.
+
+ This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular`
+ followed by `tfb.TransformDiagonal`, and provided mostly as a
+ convenience. The default setup is somewhat opinionated, using a
+ Softplus transformation followed by a small shift (`1e-5`) which
+ attempts to avoid numerical issues from zeros on the diagonal.
+
+ #### Examples
+
+ ```python
+ tfb = tf.contrib.distributions.bijectors
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Exp(),
+ diag_shift=None)
+ b.forward(x=[0., 0., 0.])
+ # Result: [[1., 0.],
+ # [0., 1.]]
+ b.inverse(y=[[1., 0],
+ [.5, 2]])
+ # Result: [log(2), .5, log(1)]
+
+ # Define a distribution over PSD matrices of shape `[3, 3]`,
+ # with `1 + 2 + 3 = 6` degrees of freedom.
+ dist = tfd.TransformedDistribution(
+ tfd.Normal(tf.zeros(6), tf.ones(6)),
+ tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()]))
+
+ # Using an identity transformation, ScaleTriL is equivalent to
+ # tfb.FillTriangular.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Identity(),
+ diag_shift=None)
+
+ # For greater control over initialization, one can manually encode
+ # pre- and post- shifts inside of `diag_bijector`.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Chain([
+ tfb.AffineScalar(shift=1e-3),
+ tfb.Softplus(),
+ tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.)
+ # = log(expm1(1.)) = 0.5413
+ diag_shift=None)
+ ```
+ """
+
+ def __init__(self,
+ diag_bijector=None,
+ diag_shift=1e-5,
+ validate_args=False,
+ name="scale_tril"):
+ """Instantiates the `ScaleTriL` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance, used to transform the output diagonal
+ to be positive.
+ Default value: `None` (i.e., `tfb.Softplus()`).
+ diag_shift: Float value broadcastable and added to all diagonal entries
+ after applying the `diag_bijector`. Setting a positive
+ value forces the output diagonal entries to be positive, but
+ prevents inverting the transformation for matrices with
+ diagonal entries less than this value.
+ Default value: `1e-5` (i.e., no shift is applied).
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ Default value: `False` (i.e., arguments are not validated).
+ name: Python `str` name given to ops managed by this object.
+ Default value: `scale_tril`.
+ """
+
+ if diag_bijector is None:
+ diag_bijector = softplus.Softplus(validate_args=validate_args)
+
+ if diag_shift is not None:
+ diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift),
+ diag_bijector])
+
+ super(ScaleTriL, self).__init__(
+ [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
+ fill_triangular.FillTriangular()],
+ validate_args=validate_args,
+ name=name)