aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py
blob: 5a464699dac5a737e0c6e0122a4a6699e945f695 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for window_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import window_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test


def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
  """A simple implementation of a raised cosine window that matches SciPy.

  https://en.wikipedia.org/wiki/Window_function#Hann_window
  https://github.com/scipy/scipy/blob/v0.14.0/scipy/signal/windows.py#L615

  Args:
    length: The window length.
    symmetric: Whether to create a symmetric window.
    a: The alpha parameter of the raised cosine window.
    b: The beta parameter of the raised cosine window.

  Returns:
    A raised cosine window of length `length`.
  """
  if length == 1:
    return np.ones(1)
  odd = length % 2
  if not symmetric and not odd:
    length += 1
  window = a - b * np.cos(2.0 * np.pi * np.arange(length) / (length - 1))
  if not symmetric and not odd:
    window = window[:-1]
  return window


class WindowOpsTest(test.TestCase):
  """Tests for window_ops.hann_window and window_ops.hamming_window.

  Each window op is compared against the NumPy reference
  `_scipy_raised_cosine` over several lengths, both `periodic` settings and
  several float dtypes. The ops are also checked for constant foldability.
  """

  def setUp(self):
    # Delegate to the base class first. Overriding setUp without doing so
    # skips whatever per-test setup test.TestCase performs. NOTE(review): in
    # TensorFlow's TestCase this presumably includes resetting the default
    # graph; confirm against the TF version in use.
    super(WindowOpsTest, self).setUp()
    # Includes the length-1 special case, small odd and even lengths, and a
    # few larger lengths.
    self._window_lengths = [1, 2, 3, 4, 5, 31, 64, 128]
    # (dtype, tolerance) pairs. Each tolerance is passed to assertAllClose
    # below as both of its tolerance arguments.
    self._dtypes = [(dtypes.float16, 1e-2),
                    (dtypes.float32, 1e-6),
                    (dtypes.float64, 1e-9)]

  def _compare_window_fns(self, np_window_fn, tf_window_fn):
    """Asserts that `tf_window_fn` agrees with `np_window_fn`.

    Every combination of window length, `periodic` flag and dtype from
    `setUp` is evaluated.

    Args:
      np_window_fn: Reference callable invoked as
        `np_window_fn(length, symmetric=...)` and returning a NumPy array.
      tf_window_fn: Op under test, invoked as
        `tf_window_fn(length, periodic=..., dtype=...)`; its result is
        evaluated with `.eval()`.
    """
    with self.test_session(use_gpu=True):
      for window_length in self._window_lengths:
        for periodic in [False, True]:
          for tf_dtype, tol in self._dtypes:
            np_dtype = tf_dtype.as_numpy_dtype
            # A periodic TF window is compared with the non-symmetric
            # reference window.
            expected = np_window_fn(window_length,
                                    symmetric=not periodic).astype(np_dtype)
            actual = tf_window_fn(window_length, periodic=periodic,
                                  dtype=tf_dtype).eval()
            self.assertAllClose(expected, actual, tol, tol)

  def test_hann_window(self):
    """Check that hann_window matches scipy.signal.hann behavior."""
    # The Hann window is a raised cosine window with parameters alpha=0.5 and
    # beta=0.5.
    # https://en.wikipedia.org/wiki/Window_function#Hann_window
    self._compare_window_fns(
        functools.partial(_scipy_raised_cosine, a=0.5, b=0.5),
        window_ops.hann_window)

  def test_hamming_window(self):
    """Check that hamming_window matches scipy.signal.hamming's behavior."""
    # The Hamming window is a raised cosine window with parameters alpha=0.54
    # and beta=0.46.
    # https://en.wikipedia.org/wiki/Window_function#Hamming_window
    self._compare_window_fns(
        functools.partial(_scipy_raised_cosine, a=0.54, b=0.46),
        window_ops.hamming_window)

  def test_constant_folding(self):
    """Window functions should be constant foldable for constant inputs."""
    for window_fn in (window_ops.hann_window, window_ops.hamming_window):
      for dtype, _ in self._dtypes:
        for periodic in [False, True]:
          # Build each window in a fresh graph so the node count below only
          # reflects that window's ops.
          g = ops.Graph()
          with g.as_default():
            window = window_fn(100, periodic=periodic, dtype=dtype)
            rewritten_graph = test_util.grappler_optimize(g, [window])
            # After optimization the whole window should collapse to a
            # single node.
            self.assertEqual(1, len(rewritten_graph.node))


if __name__ == '__main__':
  # Only when executed directly as a script (not on import): hand control to
  # TensorFlow's test entry point to run the tests in this module.
  test.main()