diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-07 11:45:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-07 11:47:20 -0700 |
commit | 642dc96bd4627a4f6305cf61b8553324054d9122 (patch) | |
tree | 49bc12899ff6ca0e631746b69cd91f34f8f52366 /tensorflow/contrib/distributions | |
parent | e343b8072833765c85a5685b0f56b1b3d6add275 (diff) |
Add FillTriangular Bijector to create triangular matrices.
PiperOrigin-RevId: 199670547
Diffstat (limited to 'tensorflow/contrib/distributions')
4 files changed, 267 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 23d9dbcd91..d8baf49e81 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -941,6 +941,25 @@ cuda_py_test( ) cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( name = "gumbel_test", size = "small", srcs = ["python/kernel_tests/bijectors/gumbel_test.py"], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000..caeaf2a0c6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 4965381ef3..59b8cf1bb2 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,6 +24,7 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@ -64,6 +65,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000..7b06325ead --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) |