aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/spectral_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/spectral_ops.py')
-rw-r--r--tensorflow/python/ops/spectral_ops.py125
1 files changed, 98 insertions, 27 deletions
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 28054f50ef..293aace728 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
raise NotImplementedError("The DCT length argument is not implemented.")
if axis != -1:
raise NotImplementedError("axis must be -1. Got: %s" % axis)
- if dct_type != 2:
- raise ValueError("Only the Type II DCT is supported.")
+ if dct_type not in (2, 3):
+ raise ValueError("Only Types II and III (I)DCT are supported.")
if norm not in (None, "ortho"):
raise ValueError(
"Unknown normalization. Expected None or 'ortho', got: %s" % norm)
@@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
- Currently only Type II is supported. Implemented using a length `2N` padded
- @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606
+ Currently only Types II and III are supported. Type II is implemented using a
+ length `2N` padded @{tf.spectral.rfft}, as described here:
+ https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward
+ inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}).
@compatibility(scipy)
- Equivalent to scipy.fftpack.dct for the Type-II DCT.
+ Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT.
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
@end_compatibility
Args:
input: A `[..., samples]` `float32` `Tensor` containing the signals to
take the DCT of.
- type: The DCT type to perform. Must be 2.
+ type: The DCT type to perform. Must be 2 or 3.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.
Raises:
- ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or
- `norm` is not `None` or `'ortho'`.
+ ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
[dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
"""
@@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1]
axis_dim_float = _math_ops.to_float(axis_dim)
- scale = 2.0 * _math_ops.exp(_math_ops.complex(
- 0.0, -_math.pi * _math_ops.range(axis_dim_float) /
- (2.0 * axis_dim_float)))
-
- # TODO(rjryan): Benchmark performance and memory usage of the various
- # approaches to computing a DCT via the RFFT.
- dct2 = _math_ops.real(
- rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
-
- if norm == "ortho":
- n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
- n2 = n1 * _math_ops.sqrt(2.0)
- # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
- weights = _array_ops.pad(
- _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
- constant_values=n2)
- dct2 *= weights
-
- return dct2
+ if type == 2:
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+
+ # TODO(rjryan): Benchmark performance and memory usage of the various
+ # approaches to computing a DCT via the RFFT.
+ dct2 = _math_ops.real(
+ rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
+
+ if norm == "ortho":
+ n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(2.0)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ dct2 *= weights
+
+ return dct2
+
+ elif type == 3:
+ if norm == "ortho":
+ n1 = _math_ops.sqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(0.5)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ input *= weights
+ else:
+ input *= axis_dim_float
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0,
+ _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+ dct3 = _math_ops.real(
+ irfft(
+ scale * _math_ops.complex(input, 0.0),
+ fft_length=[2 * axis_dim]))[..., :axis_dim]
+
+ return dct3
+
+
+# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
+@tf_export("spectral.idct")
+def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
+ """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
+
+ Currently only Types II and III are supported. Type III is the inverse of
+ Type II, and vice versa.
+
+ Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
+ not `'ortho'`. That is:
+ `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
+ When `norm='ortho'`, we have:
+ `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT.
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
+ @end_compatibility
+
+ Args:
+ input: A `[..., samples]` `float32` `Tensor` containing the signals to take
+ the DCT of.
+ type: The IDCT type to perform. Must be 2 or 3.
+ n: For future expansion. The length of the transform. Must be `None`.
+ axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
+ norm: The normalization to apply. `None` for no normalization or `'ortho'`
+ for orthonormal normalization.
+ name: An optional name for the operation.
+
+ Returns:
+ A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.
+
+ Raises:
+ ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
+
+ [idct]:
+ https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
+ """
+ _validate_dct_arguments(type, n, axis, norm)
+ inverse_type = {2: 3, 3: 2}[type]
+ return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)