aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dct_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dct_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dct_ops_test.py96
1 files changed, 69 insertions, 27 deletions
diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py
index 93b2ff4561..97d7e2d8f9 100644
--- a/tensorflow/python/kernel_tests/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/dct_ops_test.py
@@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name
fftpack = try_import("scipy.fftpack")
+def _np_dct2(signals, norm=None):
+ """Computes the DCT-II manually with NumPy."""
+ # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
+ dct_size = signals.shape[-1]
+ dct = np.zeros_like(signals)
+ for k in range(dct_size):
+ phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
+ dct[..., k] = np.sum(signals * phi, axis=-1)
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ if norm == "ortho":
+ # The orthonormal scaling includes a factor of 0.5 which we combine with
+ # the overall scaling of 2.0 to cancel.
+ dct[..., 0] *= np.sqrt(1.0 / dct_size)
+ dct[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ dct *= 2.0
+ return dct
+
+
+def _np_dct3(signals, norm=None):
+ """Computes the DCT-III manually with NumPy."""
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ dct_size = signals.shape[-1]
+ signals = np.array(signals) # make a copy so we can modify
+ if norm == "ortho":
+ signals[..., 0] *= np.sqrt(4.0 / dct_size)
+ signals[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ signals *= 2.0
+ dct = np.zeros_like(signals)
+ # X_k = 0.5 * x_0 +
+ # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1
+ half_x0 = 0.5 * signals[..., 0]
+ for k in range(dct_size):
+ phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size)
+ dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1)
+ return dct
+
+
+NP_DCT = {2: _np_dct2, 3: _np_dct3}
+NP_IDCT = {2: _np_dct3, 3: _np_dct2}
+
+
class DCTOpsTest(test.TestCase):
- def _np_dct2(self, signals, norm=None):
- """Computes the DCT-II manually with NumPy."""
- # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
- dct_size = signals.shape[-1]
- dct = np.zeros_like(signals)
- for k in range(dct_size):
- phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
- dct[..., k] = np.sum(signals * phi, axis=-1)
- # SciPy's `dct` has a scaling factor of 2.0 which we follow.
- # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
- if norm == "ortho":
- # The orthonormal scaling includes a factor of 0.5 which we combine with
- # the overall scaling of 2.0 to cancel.
- dct[..., 0] *= np.sqrt(1.0 / dct_size)
- dct[..., 1:] *= np.sqrt(2.0 / dct_size)
- else:
- dct *= 2.0
- return dct
-
- def _compare(self, signals, norm, atol=5e-4, rtol=5e-4):
- """Compares the DCT to SciPy (if available) and a NumPy implementation."""
- np_dct = self._np_dct2(signals, norm)
- tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval()
+ def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4):
+ """Compares (I)DCT to SciPy (if available) and a NumPy implementation."""
+ np_dct = NP_DCT[dct_type](signals, norm)
+ tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval()
self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
+ np_idct = NP_IDCT[dct_type](signals, norm)
+ tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval()
+ self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
if fftpack:
- scipy_dct = fftpack.dct(signals, type=2, norm=norm)
+ scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
+ scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
+ self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
+ # Verify inverse(forward(s)) == s, up to a normalization factor.
+ tf_idct_dct = spectral_ops.idct(
+ tf_dct, type=dct_type, norm=norm).eval()
+ tf_dct_idct = spectral_ops.dct(
+ tf_idct, type=dct_type, norm=norm).eval()
+ if norm is None:
+ tf_idct_dct *= 0.5 / signals.shape[-1]
+ tf_dct_idct *= 0.5 / signals.shape[-1]
+ self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
+ self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
def test_random(self):
"""Test randomly generated batches of data."""
with spectral_ops_test_util.fft_kernel_label_map():
with self.test_session(use_gpu=True):
- for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]):
+ for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]):
signals = np.random.rand(*shape).astype(np.float32)
for norm in (None, "ortho"):
- self._compare(signals, norm)
+ self._compare(signals, norm, 2)
+ self._compare(signals, norm, 3)
def test_error(self):
signals = np.random.rand(10)
# Unsupported type.
with self.assertRaises(ValueError):
- spectral_ops.dct(signals, type=3)
+ spectral_ops.dct(signals, type=1)
# Unknown normalization.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, norm="bad")