aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_identity_test.py
blob: 1a6d79e67dc957a1d7a4076b6db5af0d16bb7ca7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.contrib.distributions.python.ops import operator_pd_identity
from tensorflow.contrib.distributions.python.ops import operator_test_util

# Short alias for the TF contrib distributions module.
# NOTE(review): not referenced anywhere else in this file as shown; candidate
# for removal once confirmed nothing imports it from this test module.
distributions = tf.contrib.distributions


class OperatorPDIdentityTest(operator_test_util.OperatorPDDerivedClassTest):
  """Tests for `OperatorPDIdentity`.

  Most tests are done in the base class, which drives them through
  `_build_operator_and_mat`. The tests defined here check that `matmul`,
  `solve` and `sqrt_solve` reject arguments with a bad dtype, rank, or shape.
  """

  def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
    """Builds a scaled identity operator and the dense matrix it should match.

    This evaluates a tensor, so it must be called inside a session context
    (e.g. `with self.test_session():`).

    Args:
      batch_shape: Sequence of ints, the leading batch dimensions.
      k: Python int, size of each square matrix.
      dtype: numpy dtype used for both the operator and the matrix.

    Returns:
      operator: `OperatorPDIdentity` constructed with shape
        `batch_shape + [k, k]` and `scale=2`.
      mat: numpy array of `2 * I` (built from `tf.matrix_diag` of a ones
        tensor of shape `batch_shape + [k]`); the operator should act the
        same way as this matrix.
    """
    batch_shape = list(batch_shape)
    diag_shape = batch_shape + [k]
    matrix_shape = batch_shape + [k, k]
    diag = tf.ones(diag_shape, dtype=dtype)
    scale = tf.constant(2.0, dtype=dtype)
    # The operator is a *scaled* identity, so the reference matrix is scaled
    # by the same factor.
    scaled_identity_matrix = scale * tf.matrix_diag(diag)
    operator = operator_pd_identity.OperatorPDIdentity(
        matrix_shape, dtype, scale=scale)
    return operator, scaled_identity_matrix.eval()

  def _build_float32_operator_and_x(self):
    """Builds a float32 operator and a right-hand side it should accept.

    Shared setup for the bad-argument tests below. Must be called inside a
    session context, because `_build_operator_and_mat` evaluates a tensor.

    Returns:
      operator: Operator with batch shape [2, 3], k = 4 and dtype float32.
      x_good_shape: Python list, `batch_shape + [k, 5]`.
      x_good: Random float32 numpy array of shape `x_good_shape`, expected to
        be a valid argument for `matmul`, `solve` and `sqrt_solve`.
    """
    dtype = np.float32
    batch_shape = [2, 3]
    k = 4
    operator, _ = self._build_operator_and_mat(batch_shape, k, dtype=dtype)
    x_good_shape = batch_shape + [k, 5]
    x_good = self._rng.randn(*x_good_shape).astype(dtype)
    return operator, x_good_shape, x_good

  def _assert_matmul_and_solves_raise(self, operator, x_bad, error_type,
                                      regexp):
    """Asserts `matmul`, `solve` and `sqrt_solve` each raise on `x_bad`.

    Args:
      operator: The operator under test.
      x_bad: Invalid argument passed to each method.
      error_type: Exception class each call is expected to raise.
      regexp: Pattern the exception message is expected to match.
    """
    # Kept as three separate statements (not a loop) so that a failure's
    # traceback line identifies which method did not raise.
    with self.assertRaisesRegexp(error_type, regexp):
      operator.matmul(x_bad)
    with self.assertRaisesRegexp(error_type, regexp):
      operator.solve(x_bad)
    with self.assertRaisesRegexp(error_type, regexp):
      operator.sqrt_solve(x_bad)

  def testBadDtypeArgsRaise(self):
    # Same shape as a valid argument, but float64 instead of float32.
    with self.test_session():
      operator, _, x_good = self._build_float32_operator_and_x()
      x_bad = x_good.astype(np.float64)

      operator.matmul(x_good).eval()  # Should not raise.
      self._assert_matmul_and_solves_raise(
          operator, x_bad, TypeError, "dtype")

  def testBadRankArgsRaise(self):
    # Prepend a singleton dimension, changing the rank of "x", but not the size.
    with self.test_session():
      operator, x_good_shape, x_good = self._build_float32_operator_and_x()
      # Derived from x_good_shape rather than hard-coded, so the test keeps its
      # meaning if the shared batch shape or k ever change.
      x_bad = x_good.reshape([1] + x_good_shape)

      operator.matmul(x_good).eval()  # Should not raise.
      self._assert_matmul_and_solves_raise(
          operator, x_bad, ValueError, "tensor rank")

  def testIncompatibleShapeArgsRaise(self):
    # Test shapes that are the same rank (and size) but incompatible for matrix
    # multiplication: the last two dimensions of "x" are swapped, so its
    # second-to-last dimension (5) no longer matches the operator's k (4).
    with self.test_session():
      operator, x_good_shape, x_good = self._build_float32_operator_and_x()
      x_bad_shape = x_good_shape[:-2] + [x_good_shape[-1], x_good_shape[-2]]
      x_bad = x_good.reshape(x_bad_shape)

      operator.matmul(x_good).eval()  # Should not raise.
      self._assert_matmul_and_solves_raise(
          operator, x_bad, ValueError, "Incompatible")


if __name__ == "__main__":
  # Run the tests in this module via TensorFlow's test entry point when the
  # file is executed directly.
  tf.test.main()