aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 12:52:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 12:55:19 -0700
commit9639db8d18d979e98061504a2c6ee4bba0f74610 (patch)
tree5547edc47e79e2a15f6cddadb748558fcab0eb7a /tensorflow/contrib/distributions
parent2857228ba6c7b357185e7a0af346f4fc93a10f74 (diff)
Add TransformDiagonal higher-order bijector to transform only the diagonal of a matrix.
PiperOrigin-RevId: 199680859
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py66
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py102
4 files changed, 189 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index d8baf49e81..61d4e90ea2 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -1255,6 +1255,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "transform_diagonal_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "weibull_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/weibull_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
new file mode 100644
index 0000000000..6428a68702
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class TransformDiagonalBijectorTest(test.TestCase):
+ """Tests correctness of the TransformDiagonal bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBijector(self):
+ x = np.float32(np.random.randn(3, 4, 4))
+
+ y = x.copy()
+ for i in range(x.shape[0]):
+ np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :])))
+
+ exp = bijectors.Exp()
+ b = bijectors.TransformDiagonal(diag_bijector=exp)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllEqual(
+ fldj,
+ self.evaluate(exp.forward_log_det_jacobian(
+ np.array([np.diag(x_mat) for x_mat in x]),
+ event_ndims=1)))
+ self.assertAllEqual(
+ ildj,
+ self.evaluate(exp.inverse_log_det_jacobian(
+ np.array([np.diag(y_mat) for y_mat in y]),
+ event_ndims=1)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 59b8cf1bb2..d97a1f0d30 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -43,6 +43,7 @@
@@Softplus
@@Softsign
@@Square
+@@TransformDiagonal
@@Weibull
@@masked_autoregressive_default_template
@@ -83,6 +84,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered impo
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.contrib.distributions.python.ops.bijectors.softsign import *
from tensorflow.contrib.distributions.python.ops.bijectors.square import *
+from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import *
from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
new file mode 100644
index 0000000000..65669fc2bf
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import bijector
+
+__all__ = [
+ "TransformDiagonal",
+]
+
+
+class TransformDiagonal(bijector.Bijector):
+ """Applies a Bijector to the diagonal of a matrix.
+
+ #### Example
+
+ ```python
+ b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())
+
+ b.forward([[1., 0.],
+ [0., 1.]])
+ # ==> [[2.718, 0.],
+ [0., 2.718]]
+ ```
+
+ """
+
+ def __init__(self,
+ diag_bijector,
+ validate_args=False,
+ name="transform_diagonal"):
+ """Instantiates the `TransformDiagonal` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance used to transform the diagonal.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._diag_bijector = diag_bijector
+ super(TransformDiagonal, self).__init__(
+ forward_min_event_ndims=2,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x))
+ return array_ops.matrix_set_diag(x, diag)
+
+ def _inverse(self, y):
+ diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y))
+ return array_ops.matrix_set_diag(y, diag)
+
+ def _forward_log_det_jacobian(self, x):
+ # We formulate the Jacobian with respect to the flattened matrices
+ # `vec(x)` and `vec(y)`. Suppose for notational convenience that
+ # the first `n` entries of `vec(x)` are the diagonal of `x`, and
+ # the remaining `n**2-n` entries are the off-diagonals in
+ # arbitrary order. Then the Jacobian is a block-diagonal matrix,
+ # with the Jacobian of the diagonal bijector in the first block,
+ # and the identity Jacobian for the remaining entries (since this
+ # bijector acts as the identity on non-diagonal entries):
+ #
+ # J_vec(x) (vec(y)) =
+ # -------------------------------
+ # | J_diag(x) (diag(y)) 0 | n entries
+ # | |
+ # | 0 I | n**2-n entries
+ # -------------------------------
+ # n n**2-n
+ #
+ # Since the log-det of the second (identity) block is zero, the
+ # overall log-det-jacobian is just the log-det of first block,
+ # from the diagonal bijector.
+ #
+ # Note that for elementwise operations (exp, softplus, etc) the
+ # first block of the Jacobian will itself be a diagonal matrix,
+ # but our implementation does not require this to be true.
+ return self._diag_bijector.forward_log_det_jacobian(
+ array_ops.matrix_diag_part(x), event_ndims=1)
+
+ def _inverse_log_det_jacobian(self, y):
+ return self._diag_bijector.inverse_log_det_jacobian(
+ array_ops.matrix_diag_part(y), event_ndims=1)