aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
blob: 49a9afe3f6debe048369c52328fb5534946ab9e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MatrixInverseTriL bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class MatrixInverseTriLBijectorTest(test.TestCase):
  """Tests the correctness of the Y = inv(tril) transformation."""

  def _inv(self, x):
    """Returns the inverse of each lower-triangular matrix in `x`.

    `np.linalg.inv` is a general-purpose inverse and can leave round-off noise
    in the strictly upper triangle of its result (see
    https://github.com/numpy/numpy/issues/11445). The exact inverse of a
    lower-triangular matrix is itself lower triangular, and the bijector built
    with `validate_args=True` rejects inputs that are not lower triangular, so
    those entries are forced to exactly zero.

    Args:
      x: NumPy array of shape `[..., n, n]` holding nonsingular
        lower-triangular matrices.

    Returns:
      NumPy array of the same shape holding the lower-triangular inverses.
    """
    y = np.linalg.inv(x)
    # Indexing with `...` applies the same strictly-upper-triangle mask to
    # every matrix in the batch at once, for any number of batch dimensions.
    rows, cols = np.triu_indices(y.shape[-1], 1)
    y[..., rows, cols] = 0.
    return y

  def _expected_fldj(self, x):
    """Returns the expected forward log-det-Jacobian of X -> inv(X).

    Restricted to the n(n+1)/2 free entries of an n x n lower-triangular
    matrix X, the Jacobian of matrix inversion satisfies
    |det J| = prod_k |X_kk|^-(n+1), so the log-det-Jacobian is
    -(n + 1) * sum_k log|X_kk|. For n = 1 this reduces to -2 * log|x|, the
    derivative of 1/x.

    Args:
      x: NumPy array of shape `[..., n, n]` (lower triangular, nonsingular).

    Returns:
      NumPy array of shape `[...]`: one log-det-Jacobian per matrix.
    """
    n = x.shape[-1]
    diag = np.diagonal(x, axis1=-2, axis2=-1)
    return -(n + 1.) * np.sum(np.log(np.abs(diag)), axis=-1)

  @test_util.run_in_graph_and_eager_modes
  def testComputesCorrectValues(self):
    """Checks forward, inverse and both log-det-Jacobians on a 3x3 matrix."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    self.assertEqual("matrix_inverse_tril", inv.name)
    x_ = np.array([[0.7, 0., 0.],
                   [0.1, -1., 0.],
                   [0.3, 0.25, 0.5]], dtype=np.float32)
    # Use the cleaned-up inverse so `inv.inverse` never sees round-off noise
    # above the diagonal (it validates lower-triangularity).
    x_inv_ = self._inv(x_)
    expected_fldj_ = self._expected_fldj(x_)

    y = inv.forward(x_)
    x_back = inv.inverse(x_inv_)
    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)

    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes
  def testOneByOneMatrix(self):
    """Checks the scalar-like 1x1 case against hand-computed values."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.array([[5.]], dtype=np.float32)
    x_inv_ = np.array([[0.2]], dtype=np.float32)
    # d(1/x)/dx = -1/x**2, so log|J| = log(1/25) = log(0.04).
    expected_fldj_ = np.log(0.04)

    y = inv.forward(x_)
    x_back = inv.inverse(x_inv_)
    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)

    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes
  def testZeroByZeroMatrix(self):
    """Checks that an empty 0x0 matrix maps to itself with zero log-det."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.eye(0, dtype=np.float32)
    x_inv_ = np.eye(0, dtype=np.float32)
    expected_fldj_ = 0.

    y = inv.forward(x_)
    x_back = inv.inverse(x_inv_)
    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)

    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertNear(expected_fldj_, fldj_, err=1e-3)
    self.assertNear(-expected_fldj_, ildj_, err=1e-3)

  @test_util.run_in_graph_and_eager_modes
  def testBatch(self):
    """Checks batched computation on a (2, 1, 2, 2) input."""
    # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
    # (2, 1).
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.array([[[[1., 0.],
                     [2., 3.]]],
                   [[[4., 0.],
                     [5., -6.]]]], dtype=np.float32)
    x_inv_ = self._inv(x_)
    # One log-det-Jacobian per batch member, shape (2, 1).
    expected_fldj_ = self._expected_fldj(x_)

    y = inv.forward(x_)
    x_back = inv.inverse(x_inv_)
    fldj = inv.forward_log_det_jacobian(x_, event_ndims=2)
    ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2)

    y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj])

    self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5)
    self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5)
    self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
    self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)

  @test_util.run_in_graph_and_eager_modes
  def testErrorOnInputRankTooLow(self):
    """Expects a ValueError when the input has rank below 2."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.array([0.1], dtype=np.float32)
    rank_error_msg = "must have rank at least 2"
    with self.test_session():
      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
        inv.forward(x_).eval()
      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
        inv.inverse(x_).eval()
      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
      with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()

  # TODO(b/80481923): Figure out why these assertions fail, and fix them.
  ## def testErrorOnInputNonSquare(self):
  ##   inv = bijectors.MatrixInverseTriL(validate_args=True)
  ##   x_ = np.array([[1., 2., 3.],
  ##                  [4., 5., 6.]], dtype=np.float32)
  ##   square_error_msg = "must be a square matrix"
  ##   with self.test_session():
  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
  ##                                              square_error_msg):
  ##       inv.forward(x_).eval()
  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
  ##                                              square_error_msg):
  ##       inv.inverse(x_).eval()
  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
  ##                                              square_error_msg):
  ##       inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
  ##     with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
  ##                                              square_error_msg):
  ##       inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()

  @test_util.run_in_graph_and_eager_modes
  def testErrorOnInputNotLowerTriangular(self):
    """Expects InvalidArgumentError when entries above the diagonal are set."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.array([[1., 2.],
                   [3., 4.]], dtype=np.float32)
    triangular_error_msg = "must be lower triangular"
    with self.test_session():
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               triangular_error_msg):
        inv.forward(x_).eval()
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               triangular_error_msg):
        inv.inverse(x_).eval()
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               triangular_error_msg):
        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               triangular_error_msg):
        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()

  @test_util.run_in_graph_and_eager_modes
  def testErrorOnInputSingular(self):
    """Expects an op error when a diagonal entry is zero (singular input)."""
    inv = bijectors.MatrixInverseTriL(validate_args=True)
    x_ = np.array([[1., 0.],
                   [0., 0.]], dtype=np.float32)
    nonsingular_error_msg = "must have all diagonal entries nonzero"
    with self.test_session():
      with self.assertRaisesOpError(nonsingular_error_msg):
        inv.forward(x_).eval()
      with self.assertRaisesOpError(nonsingular_error_msg):
        inv.inverse(x_).eval()
      with self.assertRaisesOpError(nonsingular_error_msg):
        inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
      with self.assertRaisesOpError(nonsingular_error_msg):
        inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()

if __name__ == "__main__":
  # Run the test cases in this module via TensorFlow's test runner.
  test.main()