about summary refs log tree commit diff homepage
path: root/tensorflow/python/ops/spectral_ops_test_util.py
blob: 1f2e730edc8f57582a2f2075fdc2d8614e6b9582 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for writing tests involving spectral_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.platform import test


def _use_eigen_kernels():
  """Reports whether FFT tests should force the "eigen" kernel label.

  Eigen kernels are the default registered kernels, so no override is
  needed to exercise them. On CUDA GPUs they are not forced either. Because
  every configuration yields the same answer, no device probe is performed
  here (probing for a GPU would only add cost without affecting the result).

  Flip the return value to True to make `fft_kernel_label_map` map every
  FFT-family op to the "eigen" kernel label.

  Returns:
    False.
  """
  return False


def fft_kernel_label_map():
  """Returns a context manager overriding FFT kernel selection.

  This is used to force testing of the eigen kernels, even
  when they are not the default registered kernels. Intended usage:

    with fft_kernel_label_map():
      ...  # build and run FFT ops

  Returns:
    The context manager produced by the default graph's
    `_kernel_label_map`, in which to wrap every test. When
    `_use_eigen_kernels()` is true, every FFT-family op (FFT, IFFT, RFFT,
    IRFFT and their 2D/3D variants) is mapped to the "eigen" kernel label;
    otherwise an empty label map is passed.
  """
  if _use_eigen_kernels():
    fft_op_types = (
        "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D",
        "IRFFT", "IRFFT2D", "IRFFT3D", "RFFT", "RFFT2D", "RFFT3D",
    )
    label_map = {op_type: "eigen" for op_type in fft_op_types}
  else:
    # No overrides: kernel selection is left to the graph's defaults.
    label_map = {}
  return ops.get_default_graph()._kernel_label_map(label_map)  # pylint: disable=protected-access