1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for window_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import window_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
  """A simple implementation of a raised cosine window that matches SciPy.

  https://en.wikipedia.org/wiki/Window_function#Hann_window
  https://github.com/scipy/scipy/blob/v0.14.0/scipy/signal/windows.py#L615

  Args:
    length: The window length.
    symmetric: Whether to create a symmetric window. When False and `length`
      is even, a window of `length + 1` samples is computed and its final
      sample is dropped.
    a: The alpha parameter of the raised cosine window.
    b: The beta parameter of the raised cosine window.

  Returns:
    A raised cosine window of length `length`. The result is empty when
    `length` is less than 1.
  """
  if length < 1:
    # Without this guard, length 0 with symmetric=False would be extended to a
    # single sample and evaluate 0 / 0 before being sliced down to nothing.
    return np.array([])
  if length == 1:
    return np.ones(1)
  # Only even-length, non-symmetric windows are extended by one sample.
  extend = not symmetric and length % 2 == 0
  num_points = length + 1 if extend else length
  window = a - b * np.cos(
      2.0 * np.pi * np.arange(num_points) / (num_points - 1))
  return window[:-1] if extend else window
class WindowOpsTest(test.TestCase):
  """Checks window_ops against a NumPy raised-cosine reference."""

  def setUp(self):
    # Chain to the base class so its own per-test setup is not skipped.
    super(WindowOpsTest, self).setUp()
    self._window_lengths = [1, 2, 3, 4, 5, 31, 64, 128]
    # (dtype, tolerance) pairs; lower-precision dtypes get looser tolerances.
    self._dtypes = [(dtypes.float16, 1e-2),
                    (dtypes.float32, 1e-6),
                    (dtypes.float64, 1e-9)]

  def _compare_window_fns(self, np_window_fn, tf_window_fn):
    """Asserts that a TF window function agrees with a NumPy reference.

    Every combination of window length, periodic flag and dtype configured in
    `setUp` is evaluated.

    Args:
      np_window_fn: Reference callable taking `(length, symmetric=...)` and
        returning a NumPy array.
      tf_window_fn: Callable taking `(length, periodic=..., dtype=...)` and
        returning a tensor.
    """
    with self.test_session(use_gpu=True):
      for window_length in self._window_lengths:
        for periodic in [False, True]:
          for tf_dtype, tol in self._dtypes:
            np_dtype = tf_dtype.as_numpy_dtype
            # A periodic TF window corresponds to a non-symmetric reference.
            expected = np_window_fn(window_length,
                                    symmetric=not periodic).astype(np_dtype)
            actual = tf_window_fn(window_length, periodic=periodic,
                                  dtype=tf_dtype).eval()
            # The same per-dtype tolerance is passed for both tolerance
            # arguments.
            self.assertAllClose(expected, actual, tol, tol)

  def test_hann_window(self):
    """Check that hann_window matches scipy.signal.hann behavior."""
    # The Hann window is a raised cosine window with parameters alpha=0.5 and
    # beta=0.5.
    # https://en.wikipedia.org/wiki/Window_function#Hann_window
    self._compare_window_fns(
        functools.partial(_scipy_raised_cosine, a=0.5, b=0.5),
        window_ops.hann_window)

  def test_hamming_window(self):
    """Check that hamming_window matches scipy.signal.hamming's behavior."""
    # The Hamming window is a raised cosine window with parameters alpha=0.54
    # and beta=0.46.
    # https://en.wikipedia.org/wiki/Window_function#Hamming_window
    self._compare_window_fns(
        functools.partial(_scipy_raised_cosine, a=0.54, b=0.46),
        window_ops.hamming_window)

  def test_constant_folding(self):
    """Window functions should be constant foldable for constant inputs."""
    for window_fn in (window_ops.hann_window, window_ops.hamming_window):
      for dtype, _ in self._dtypes:
        for periodic in [False, True]:
          # Build each window in a fresh graph so the node count below
          # reflects only this window.
          g = ops.Graph()
          with g.as_default():
            window = window_fn(100, periodic=periodic, dtype=dtype)
            rewritten_graph = test_util.grappler_optimize(g, [window])
            # The optimized graph is expected to contain exactly one node.
            self.assertEqual(1, len(rewritten_graph.node))
# Hand control to `test.main()` only when this file is executed as a script.
if __name__ == '__main__':
  test.main()
|