aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
blob: 49b9f1feb8e9feaa55014fc4ac8fd647255b19b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing `LinearOperator` and sub-classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import numpy as np
import six

from tensorflow.contrib.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test


@six.add_metaclass(abc.ABCMeta)  # pylint: disable=no-init
class LinearOperatorDerivedClassTest(test.TestCase):
  """Tests for derived classes.

  Subclasses should implement every abstractmethod, and this will enable all
  test methods to work.

  Each `test_*` method builds an operator together with a reference (batch)
  matrix via `_operator_and_mat_and_feed_dict`, evaluates the same quantity
  through both, and compares the results with `assertAC`.
  """

  # Absolute/relative tolerance for tests, keyed by TensorFlow dtype.
  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-12,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-12
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-12,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-12
  }

  def assertAC(self, x, y):
    """Assert `x` and `y` are close, with tolerance chosen from `x.dtype`.

    Derived classes can set _atol, _rtol to get different tolerance.

    Args:
      x:  Value with a `dtype` attribute accepted by `dtypes.as_dtype`.
      y:  Value to compare against `x`.
    """
    dtype = dtypes.as_dtype(x.dtype)
    atol = self._atol[dtype]
    rtol = self._rtol[dtype]
    self.assertAllClose(x, y, atol=atol, rtol=rtol)

  @property
  def _dtypes_to_test(self):
    """Returns list of TensorFlow dtypes every test iterates over."""
    # TODO(langmore) Test tf.float16 once tf.matrix_solve works in 16bit.
    return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]

  @abc.abstractproperty
  def _shapes_to_test(self):
    """Returns list of tuples, each is one shape that will be tested."""
    raise NotImplementedError("shapes_to_test has not been implemented.")

  @abc.abstractmethod
  def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
    """Build a batch matrix and an Operator that should have similar behavior.

    Every operator acts like a (batch) matrix.  This method returns both
    together, and is used by tests.

    Args:
      shape:  List-like of Python integers giving full shape of operator.
      dtype:  TensorFlow dtype (the tests pass entries of `_dtypes_to_test`).
        Data type of returned matrix/operator.
      use_placeholder:  Python bool.  If True, initialize the operator with a
        placeholder of undefined shape and correct dtype.

    Returns:
      operator:  `LinearOperator` subclass instance.
      mat:  `Tensor` representing operator.
      feed_dict:  Dictionary.
        If `use_placeholder` is True, this must contain everything needed to
          be fed to sess.run calls at runtime to make the operator work.
    """
    # Create a matrix as a numpy array with desired shape/dtype.
    # Create a LinearOperator that should have the same behavior as the matrix.
    raise NotImplementedError("Not implemented yet.")

  @abc.abstractmethod
  def _make_rhs(self, operator, adjoint):
    """Make a rhs appropriate for calling operator.solve(rhs).

    Args:
      operator:  A `LinearOperator`
      adjoint:  Python `bool`.  If `True`, we are making a 'rhs' value for the
        adjoint operator.

    Returns:
      A `Tensor`
    """
    raise NotImplementedError("_make_rhs is not defined.")

  @abc.abstractmethod
  def _make_x(self, operator, adjoint):
    """Make an 'x' appropriate for calling operator.apply(x).

    Args:
      operator:  A `LinearOperator`
      adjoint:  Python `bool`.  If `True`, we are making an 'x' value for the
        adjoint operator.

    Returns:
      A `Tensor`
    """
    raise NotImplementedError("_make_x is not defined.")

  @property
  def _tests_to_skip(self):
    """List of test names to skip."""
    # Subclasses should over-ride if they want to skip some tests.
    # To skip "test_foo", add "foo" to this list.
    return []

  def _maybe_skip(self, test_name):
    """Skip the running test if `test_name` is listed in `_tests_to_skip`.

    Args:
      test_name:  Python `str`.  Test name without the "test_" prefix.
    """
    if test_name in self._tests_to_skip:
      # Interpolate the name so the skip reason says which test was skipped.
      self.skipTest(
          "%s skipped because it was added to self._tests_to_skip." %
          test_name)

  def test_to_dense(self):
    """`operator.to_dense()` should match the reference matrix."""
    self._maybe_skip("to_dense")
    with self.test_session() as sess:
      for use_placeholder in False, True:
        for shape in self._shapes_to_test:
          for dtype in self._dtypes_to_test:
            operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                shape, dtype, use_placeholder=use_placeholder)
            op_dense = operator.to_dense()
            # Static shape is only checked when no placeholder is involved.
            if not use_placeholder:
              self.assertAllEqual(shape, op_dense.get_shape())
            op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict)
            self.assertAC(op_dense_v, mat_v)

  def test_det(self):
    """`operator.determinant()` should match `matrix_determinant(mat)`."""
    self._maybe_skip("det")
    with self.test_session() as sess:
      for use_placeholder in False, True:
        for shape in self._shapes_to_test:
          for dtype in self._dtypes_to_test:
            # NOTE(review): skipTest presumably raises (unittest semantics), so
            # the first complex dtype ends the whole test method and later
            # shapes / the placeholder pass never run.  Confirm whether
            # `continue` was intended.
            if dtype.is_complex:
              self.skipTest(
                  "tf.matrix_determinant does not work with complex, so this "
                  "test is being skipped.")
            operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                shape, dtype, use_placeholder=use_placeholder)
            op_det = operator.determinant()
            if not use_placeholder:
              # The determinant has the batch shape (matrix dims dropped).
              self.assertAllEqual(shape[:-2], op_det.get_shape())
            op_det_v, mat_det_v = sess.run(
                [op_det, linalg_ops.matrix_determinant(mat)],
                feed_dict=feed_dict)
            self.assertAC(op_det_v, mat_det_v)

  def test_apply(self):
    """`operator.apply(x)` should match `matmul(mat, x)`, with/without adjoint."""
    self._maybe_skip("apply")
    with self.test_session() as sess:
      for use_placeholder in False, True:
        for shape in self._shapes_to_test:
          for dtype in self._dtypes_to_test:
            for adjoint in False, True:
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  shape, dtype, use_placeholder=use_placeholder)
              x = self._make_x(operator, adjoint=adjoint)
              op_apply = operator.apply(x, adjoint=adjoint)
              mat_apply = math_ops.matmul(mat, x, adjoint_a=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
              op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply],
                                                 feed_dict=feed_dict)
              self.assertAC(op_apply_v, mat_apply_v)

  def test_solve(self):
    """`operator.solve(rhs)` should match `matrix_solve(mat, rhs)`."""
    self._maybe_skip("solve")
    with self.test_session() as sess:
      for use_placeholder in False, True:
        for shape in self._shapes_to_test:
          for dtype in self._dtypes_to_test:
            for adjoint in False, True:
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  shape, dtype, use_placeholder=use_placeholder)
              rhs = self._make_rhs(operator, adjoint=adjoint)
              op_solve = operator.solve(rhs, adjoint=adjoint)
              mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
              if not use_placeholder:
                self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
              op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve],
                                                 feed_dict=feed_dict)
              self.assertAC(op_solve_v, mat_solve_v)

  def test_add_to_tensor(self):
    """`operator.add_to_tensor(2 * mat)` should equal `3 * mat`."""
    self._maybe_skip("add_to_tensor")
    with self.test_session() as sess:
      for use_placeholder in False, True:
        for shape in self._shapes_to_test:
          for dtype in self._dtypes_to_test:
            operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                shape, dtype, use_placeholder=use_placeholder)
            op_plus_2mat = operator.add_to_tensor(2 * mat)

            if not use_placeholder:
              self.assertAllEqual(shape, op_plus_2mat.get_shape())

            op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat],
                                             feed_dict=feed_dict)

            # The operator acts like `mat`, so operator + 2 * mat == 3 * mat.
            self.assertAC(op_plus_2mat_v, 3 * mat_v)


@six.add_metaclass(abc.ABCMeta)
class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
  """Base test class for operators whose matrix dimensions are equal.

  Provides `_shapes_to_test`, `_make_rhs` and `_make_x`.  Sub-classes must
  still define all abstractmethods from LinearOperatorDerivedClassTest that
  are not defined here.
  """

  @property
  def _shapes_to_test(self):
    """Square shapes: non-batch (n, n) first, then batch shapes."""
    return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)]

  def _make_rhs(self, operator, adjoint):
    """Make a rhs for `operator.solve`, by delegating to `_make_x`."""
    # For a square operator rhs and x share a shape, so the adjoint flag does
    # not change the result.  It is flipped anyway, to be pedantic.
    flipped = not adjoint
    return self._make_x(operator, adjoint=flipped)

  def _make_x(self, operator, adjoint):
    """Make a random `x` of shape [B1,...,Bb, N, R] for `operator.apply`.

    Args:
      operator:  A `LinearOperator` with shape [B1,...,Bb, N, N].
      adjoint:  Python `bool`.  Unused, since the operator is square.

    Returns:
      A `Tensor` with dtype `operator.dtype`; R (1 or 2) comes from
      `_get_num_systems`.
    """
    num_systems = self._get_num_systems(operator)

    if not operator.shape.is_fully_defined():
      # Shape is only known at graph run time, so assemble it as a Tensor.
      dynamic_batch = operator.batch_shape_dynamic()
      dynamic_n = operator.domain_dimension_dynamic()
      x_shape = array_ops.concat_v2(
          (dynamic_batch, [dynamic_n, num_systems]), 0)
    else:
      # Shape is fully known statically, so a Python list suffices.
      static_batch = operator.batch_shape.as_list()
      static_n = operator.domain_dimension.value
      x_shape = static_batch + [static_n, num_systems]

    return random_normal(x_shape, dtype=operator.dtype)

  def _get_num_systems(self, operator):
    """Return 1 if the tensor rank is unknown or odd, and 2 if it is even."""
    rank = operator.tensor_rank
    return 1 if rank is None or rank % 2 else 2


@six.add_metaclass(abc.ABCMeta)
class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
  """Base test class appropriate for generic rectangular operators.

  Square shapes are never tested by this class, so if you want to test your
  operator with a square shape, create two test classes, the other subclassing
  SquareLinearOperatorDerivedClassTest.

  Sub-classes must still define all abstractmethods from
  LinearOperatorDerivedClassTest that are not defined here.
  """

  @property
  def _tests_to_skip(self):
    """List of test names to skip: "solve" and "det"."""
    return ["solve", "det"]

  @property
  def _shapes_to_test(self):
    """Rectangular shapes: non-batch (m, n) first, then batch shapes."""
    return [(2, 1), (1, 2), (1, 3, 2), (3, 3, 4), (2, 1, 2, 4)]

  def _make_rhs(self, operator, adjoint):
    """Not supported: the "solve" test is skipped by this class."""
    # TODO(langmore) Add once we're testing solve_ls.
    raise NotImplementedError(
        "_make_rhs not implemented because we don't test solve")

  def _make_x(self, operator, adjoint):
    """Make a random `x` suitable for `operator.apply(x, adjoint=adjoint)`.

    Args:
      operator:  A `LinearOperator` with shape [B1,...,Bb, M, N].
      adjoint:  Python `bool`.  If `True`, the second-to-last dimension of `x`
        is the operator's range dimension, otherwise its domain dimension.

    Returns:
      A `Tensor` with dtype `operator.dtype` and shape [B1,...,Bb, K, R], where
      K is the dimension selected by `adjoint` and R (1 or 2) comes from
      `_get_num_systems`.
    """
    num_systems = self._get_num_systems(operator)

    if not operator.shape.is_fully_defined():
      # Shape is only known at graph run time, so assemble it as a Tensor.
      dynamic_batch = operator.batch_shape_dynamic()
      dynamic_rows = (operator.range_dimension_dynamic() if adjoint
                      else operator.domain_dimension_dynamic())
      x_shape = array_ops.concat_v2(
          (dynamic_batch, [dynamic_rows, num_systems]), 0)
    else:
      # Shape is fully known statically, so a Python list suffices.
      static_batch = operator.batch_shape.as_list()
      row_dimension = (operator.range_dimension if adjoint
                       else operator.domain_dimension)
      x_shape = static_batch + [row_dimension.value, num_systems]

    return random_normal(x_shape, dtype=operator.dtype)

  def _get_num_systems(self, operator):
    """Return 1 if the tensor rank is unknown or odd, and 2 if it is even."""
    rank = operator.tensor_rank
    return 1 if rank is None or rank % 2 else 2


def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
  """Build a [batch] matrix of the form `L L^H`, with `L` lower triangular.

  `L` is drawn from `random_tril_matrix`, then multiplied by its own adjoint.

  Args:
    shape:  `TensorShape`, Python list, or `Tensor`.  Shape of the returned
      matrix.  When not a `Tensor`, the last two dimensions are checked for
      compatibility (the matrix must be square).
    dtype:  `TensorFlow` `dtype` or Python dtype.
    force_well_conditioned:  Python bool.  Forwarded unchanged to
      `random_tril_matrix`, where it controls the diagonal of `L`.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  dtype = dtypes.as_dtype(dtype)

  shape_is_static = not contrib_tensor_util.is_tensor(shape)
  if shape_is_static:
    # Static shapes can be validated up front: last two dims must agree.
    shape = tensor_shape.TensorShape(shape)
    num_cols, num_rows = shape[-1], shape[-2]
    num_cols.assert_is_compatible_with(num_rows)

  with ops.name_scope("random_positive_definite_matrix"):
    lower = random_tril_matrix(
        shape, dtype, force_well_conditioned=force_well_conditioned)
    product = math_ops.matmul(lower, lower, adjoint_b=True)
  return product


def random_tril_matrix(shape,
                       dtype,
                       force_well_conditioned=False,
                       remove_upper=True):
  """[batch] lower triangular matrix.

  Args:
    shape:  `TensorShape` or Python `list`.  Shape of the returned matrix.
    dtype:  `TensorFlow` `dtype` or Python dtype
    force_well_conditioned:  Python `bool`. If `True`, the diagonal is
      replaced by `random_sign_uniform` samples drawn with `minval=1` and
      `maxval=sqrt(2)`, so the diagonal entries (the eigenvalues of a
      triangular matrix) have modulus in `[1, 2]`.  Otherwise, the diagonal
      entries are unit normal random variables.
    remove_upper:  Python `bool`.
      If `True`, zero out the strictly upper triangle.
      If `False`, the lower triangle of returned matrix will have desired
      properties, but will not have the strictly upper triangle zero'd out.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  # Normalize so Python/numpy dtypes work too: `.real_dtype` below exists only
  # on TensorFlow dtypes.
  dtype = dtypes.as_dtype(dtype)

  with ops.name_scope("random_tril_matrix"):
    # Totally random matrix.  Has no nice properties.
    tril = random_normal(shape, dtype=dtype)
    if remove_upper:
      # Keep all sub-diagonals (-1) and no super-diagonals (0).
      tril = array_ops.matrix_band_part(tril, -1, 0)

    # Create a diagonal whose real (and, if complex, imaginary) parts have
    # magnitude in [1, sqrt(2)], which bounds the modulus by [1, 2].
    if force_well_conditioned:
      maxval = ops.convert_to_tensor(np.sqrt(2.), dtype=dtype.real_dtype)
      # The diagonal has the matrix shape minus its last dimension.
      diag = random_sign_uniform(
          shape[:-1], dtype=dtype, minval=1., maxval=maxval)
      tril = array_ops.matrix_set_diag(tril, diag)

    return tril


def random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None):
  """Tensor with (possibly complex) Gaussian entries.

  Samples are distributed like

  ```
  N(mean, stddev^2), if dtype is real,
  X + iY,  where X, Y ~ N(mean, stddev^2) if dtype is complex.
  ```

  Args:
    shape:  `TensorShape` or Python list.  Shape of the returned tensor.
    mean:  `Tensor` giving mean of normal to sample from.
    stddev:  `Tensor` giving stdev of normal to sample from.
    dtype:  `TensorFlow` `dtype` or numpy dtype
    seed:  Python integer seed for the RNG.  For complex dtypes the imaginary
      part is drawn with `seed + 1234`.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  dtype = dtypes.as_dtype(dtype)
  component_dtype = dtype.real_dtype

  with ops.name_scope("random_normal"):
    real_part = random_ops.random_normal(
        shape, mean=mean, stddev=stddev, dtype=component_dtype, seed=seed)
    if not dtype.is_complex:
      return real_part

    # Offset the seed so the imaginary part is not a copy of the real part.
    imag_seed = None if seed is None else seed + 1234
    imag_part = random_ops.random_normal(
        shape, mean=mean, stddev=stddev, dtype=component_dtype, seed=imag_seed)
    return math_ops.complex(real_part, imag_part)


def random_uniform(shape,
                   minval=None,
                   maxval=None,
                   dtype=dtypes.float32,
                   seed=None):
  """Tensor with (possibly complex) Uniform entries.

  Samples are distributed like

  ```
  Uniform[minval, maxval], if dtype is real,
  X + iY,  where X, Y ~ Uniform[minval, maxval], if dtype is complex.
  ```

  Args:
    shape:  `TensorShape` or Python list.  Shape of the returned tensor.
    minval:  `0-D` `Tensor` giving the minimum values.
    maxval:  `0-D` `Tensor` giving the maximum values.
    dtype:  `TensorFlow` `dtype` or Python dtype
    seed:  Python integer seed for the RNG.  For complex dtypes the imaginary
      part is drawn with `seed + 12345`.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  dtype = dtypes.as_dtype(dtype)
  component_dtype = dtype.real_dtype

  with ops.name_scope("random_uniform"):
    real_part = random_ops.random_uniform(
        shape, dtype=component_dtype, minval=minval, maxval=maxval, seed=seed)
    if not dtype.is_complex:
      return real_part

    # Offset the seed so the imaginary part is not a copy of the real part.
    imag_seed = None if seed is None else seed + 12345
    imag_part = random_ops.random_uniform(
        shape,
        dtype=component_dtype,
        minval=minval,
        maxval=maxval,
        seed=imag_seed)
    return math_ops.complex(real_part, imag_part)


def random_sign_uniform(shape,
                        minval=None,
                        maxval=None,
                        dtype=dtypes.float32,
                        seed=None):
  """Tensor with (possibly complex) random entries from a "sign Uniform".

  Letting `Z` be a random variable equal to `-1` and `1` with equal probability,
  Samples from this `Op` are distributed like

  ```
  Z * X, where X ~ Uniform[minval, maxval], if dtype is real,
  Z * (X + iY),  where X, Y ~ Uniform[minval, maxval], if dtype is complex.
  ```

  Args:
    shape:  `TensorShape` or Python list.  Shape of the returned tensor.
    minval:  `0-D` `Tensor` giving the minimum values.
    maxval:  `0-D` `Tensor` giving the maximum values.
    dtype:  `TensorFlow` `dtype` or Python dtype
    seed:  Python integer seed for the RNG.  The signs are drawn with
      `seed + 12`.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  dtype = dtypes.as_dtype(dtype)

  with ops.name_scope("random_sign_uniform"):
    magnitudes = random_uniform(
        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)

    # Draw the signs from a separately seeded stream: the sign of a sample
    # from Uniform[-1, 1].
    sign_seed = None if seed is None else seed + 12
    symmetric_draw = random_ops.random_uniform(
        shape, minval=-1., maxval=1., seed=sign_seed)
    signs = math_ops.sign(symmetric_draw)

    return magnitudes * math_ops.cast(signs, magnitudes.dtype)