aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 11:45:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 11:47:20 -0700
commit642dc96bd4627a4f6305cf61b8553324054d9122 (patch)
tree49bc12899ff6ca0e631746b69cd91f34f8f52366 /tensorflow/contrib/distributions
parente343b8072833765c85a5685b0f56b1b3d6add275 (diff)
Add FillTriangular Bijector to create triangular matrices.
PiperOrigin-RevId: 199670547
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/BUILD19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py98
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py148
4 files changed, 267 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 23d9dbcd91..d8baf49e81 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -941,6 +941,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "fill_triangular_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "gumbel_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/gumbel_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
new file mode 100644
index 0000000000..caeaf2a0c6
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class FillTriangularBijectorTest(test.TestCase):
+ """Tests the correctness of the FillTriangular bijector."""
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBijector(self):
+ x = np.float32(np.array([1., 2., 3.]))
+ y = np.float32(np.array([[3., 0.],
+ [2., 1.]]))
+
+ b = bijectors.FillTriangular()
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ self.assertAllClose(fldj, 0.)
+
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(ildj, 0.)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testShape(self):
+ x_shape = tensor_shape.TensorShape([5, 4, 6])
+ y_shape = tensor_shape.TensorShape([5, 4, 3, 3])
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x = array_ops.ones(shape=x_shape, dtype=dtypes.float32)
+ y_ = b.forward(x)
+ self.assertAllEqual(y_.shape.as_list(), y_shape.as_list())
+ x_ = b.inverse(y_)
+ self.assertAllEqual(x_.shape.as_list(), x_shape.as_list())
+
+ y_shape_ = b.forward_event_shape(x_shape)
+ self.assertAllEqual(y_shape_.as_list(), y_shape.as_list())
+ x_shape_ = b.inverse_event_shape(y_shape)
+ self.assertAllEqual(x_shape_.as_list(), x_shape.as_list())
+
+ y_shape_tensor = self.evaluate(
+ b.forward_event_shape_tensor(x_shape.as_list()))
+ self.assertAllEqual(y_shape_tensor, y_shape.as_list())
+ x_shape_tensor = self.evaluate(
+ b.inverse_event_shape_tensor(y_shape.as_list()))
+ self.assertAllEqual(x_shape_tensor, x_shape.as_list())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testShapeError(self):
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x_shape_bad = tensor_shape.TensorShape([5, 4, 7])
+ with self.assertRaisesRegexp(ValueError, "is not a triangular number"):
+ b.forward_event_shape(x_shape_bad)
+ with self.assertRaisesOpError("is not a triangular number"):
+ self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list()))
+
+ y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2])
+ with self.assertRaisesRegexp(ValueError, "Matrix must be square"):
+ b.inverse_event_shape(y_shape_bad)
+ with self.assertRaisesOpError("Matrix must be square"):
+ self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list()))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 4965381ef3..59b8cf1bb2 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -24,6 +24,7 @@
@@CholeskyOuterProduct
@@ConditionalBijector
@@Exp
+@@FillTriangular
@@Gumbel
@@Identity
@@Inline
@@ -64,6 +65,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
+from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import *
from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
new file mode 100644
index 0000000000..7b06325ead
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.distributions import util as dist_util
+
+
+__all__ = [
+ "FillTriangular",
+]
+
+
+class FillTriangular(bijector.Bijector):
+ """Transforms vectors to triangular.
+
+ Triangular matrix elements are filled in a clockwise spiral.
+
+ Given input with shape `batch_shape + [d]`, produces output with
+ shape `batch_shape + [n, n]`, where
+ `n = (-1 + sqrt(1 + 8 * d))/2`.
+ This follows by solving the quadratic equation
+ `d = 1 + 2 + ... + n = n * (n + 1)/2`.
+
+ #### Example
+
+ ```python
+ b = tfb.FillTriangular(upper=False)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[4, 0, 0],
+ # [6, 5, 0],
+ # [3, 2, 1]]
+
+ b = tfb.FillTriangular(upper=True)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[1, 2, 3],
+ # [0, 5, 6],
+ # [0, 0, 4]]
+
+ ```
+ """
+
+ def __init__(self,
+ upper=False,
+ validate_args=False,
+ name="fill_triangular"):
+ """Instantiates the `FillTriangular` bijector.
+
+ Args:
+ upper: Python `bool` representing whether output matrix should be upper
+ triangular (`True`) or lower triangular (`False`, default).
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._upper = upper
+ super(FillTriangular, self).__init__(
+ forward_min_event_ndims=1,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ return dist_util.fill_triangular(x, upper=self._upper)
+
+ def _inverse(self, y):
+ return dist_util.fill_triangular_inverse(y, upper=self._upper)
+
+ def _forward_log_det_jacobian(self, x):
+ return array_ops.zeros_like(x[..., 0])
+
+ def _inverse_log_det_jacobian(self, y):
+ return array_ops.zeros_like(y[..., 0, 0])
+
+ def _forward_event_shape(self, input_shape):
+ batch_shape, d = input_shape[:-1], input_shape[-1].value
+ if d is None:
+ n = None
+ else:
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return batch_shape.concatenate([n, n])
+
+ def _inverse_event_shape(self, output_shape):
+ batch_shape, n1, n2 = (output_shape[:-2],
+ output_shape[-2].value,
+ output_shape[-1].value)
+ if n1 is None or n2 is None:
+ m = None
+ elif n1 != n2:
+ raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2))
+ else:
+ m = n1 * (n1 + 1) / 2
+ return batch_shape.concatenate([m])
+
+ def _forward_event_shape_tensor(self, input_shape_tensor):
+ batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1]
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return array_ops.concat([batch_shape, [n, n]], axis=0)
+
+ def _inverse_event_shape_tensor(self, output_shape_tensor):
+ batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
+ if self.validate_args:
+ is_square_matrix = check_ops.assert_equal(
+ n, output_shape_tensor[-2], message="Matrix must be square.")
+ with ops.control_dependencies([is_square_matrix]):
+ n = array_ops.identity(n)
+ d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
+ return array_ops.concat([batch_shape, [d]], axis=0)
+
+
+def vector_size_to_square_matrix_size(d, validate_args, name=None):
+ """Convert a vector size to a matrix size."""
+ if isinstance(d, (float, int, np.generic, np.ndarray)):
+ n = (-1 + np.sqrt(1 + 8 * d)) / 2.
+ if float(int(n)) != n:
+ raise ValueError("Vector length is not a triangular number.")
+ return int(n)
+ else:
+ with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name:
+ n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2.
+ if validate_args:
+ with ops.control_dependencies([check_ops.assert_equal(
+ math_ops.to_float(math_ops.to_int32(n)), n,
+ message="Vector length is not a triangular number")]):
+ n = array_ops.identity(n)
+ return math_ops.cast(n, d.dtype)