diff options
author | RJ Ryan <rjryan@google.com> | 2017-07-20 13:36:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-20 13:41:22 -0700 |
commit | 97cf16bb937e0cea7c024b2e689b198795682101 (patch) | |
tree | 1c8edc24c7d8f855843c4d67c669369f7500203c /tensorflow/contrib/signal | |
parent | 9f4ec7e8492ae4527315a7ef2df4beb5cc8b9ab8 (diff) |
Documentation fixes and polishing for tf.contrib.signal.
PiperOrigin-RevId: 162658696
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/__init__.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/shape_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/spectral_ops.py | 48 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/window_ops.py | 12 |
4 files changed, 38 insertions, 33 deletions
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index c0136346b7..6cc51d6fb0 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""##Signal ops. +"""Signal processing operations. @@frame @@hamming_window @@ -20,6 +20,10 @@ @@inverse_stft @@overlap_and_add @@stft + +[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window +[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window +[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ from __future__ import absolute_import diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py index dc7a073242..1ddc2941ec 100644 --- a/tensorflow/contrib/signal/python/ops/shape_ops.py +++ b/tensorflow/contrib/signal/python/ops/shape_ops.py @@ -57,7 +57,7 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1, name=None): """Expands `signal`'s `axis` dimension into frames of `frame_length`. - Slides a window of size `frame_length` over `signal`s `axis` dimension + Slides a window of size `frame_length` over `signal`'s `axis` dimension with a stride of `frame_step`, replacing the `axis` dimension with `[frames, frame_length]` frames. @@ -91,7 +91,8 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1, A `Tensor` of frames with shape `[..., frames, frame_length, ...]`. Raises: - ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar. + ValueError: If `frame_length`, `frame_step`, `pad_value`, or `axis` are not + scalar. """ with ops.name_scope(name, "frame", [signal, frame_length, frame_step, pad_value]): diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index 0d1ef0d3d1..75bc0bd21d 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -32,17 +32,15 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops -def stft(signal, frame_length, frame_step, fft_length=None, +def stft(signals, frame_length, frame_step, fft_length=None, window_fn=functools.partial(window_ops.hann_window, periodic=True), pad_end=False, name=None): - """Computes the Short-time Fourier Transform of a batch of real signals. - - https://en.wikipedia.org/wiki/Short-time_Fourier_transform + """Computes the [Short-time Fourier Transform][stft] of `signals`. Implemented with GPU-compatible ops and supports gradients. Args: - signal: A `[..., samples]` `float32` `Tensor` of real-valued signals. + signals: A `[..., samples]` `float32` `Tensor` of real-valued signals. frame_length: An integer scalar `Tensor`. The window length in samples. frame_step: An integer scalar `Tensor`. The number of samples to step. fft_length: An integer scalar `Tensor`. The size of the FFT to apply. @@ -50,25 +48,26 @@ def stft(signal, frame_length, frame_step, fft_length=None, window_fn: A callable that takes a window length and a `dtype` keyword argument and returns a `[window_length]` `Tensor` of samples in the provided datatype. If set to `None`, no windowing is used. - pad_end: Whether to pad the end of signal with zeros when the provided - frame length and step produces a frame that lies partially past the end - of `signal`. + pad_end: Whether to pad the end of `signals` with zeros when the provided + frame length and step produces a frame that lies partially past its end. name: An optional name for the operation. Returns: A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where - `fft_unique_bins` is `fft_length / 2 + 1` (the unique components of the + `fft_unique_bins` is `fft_length // 2 + 1` (the unique components of the FFT). Raises: - ValueError: If `signal` is not at least rank 1, `frame_length` is + ValueError: If `signals` is not at least rank 1, `frame_length` is not scalar, `frame_step` is not scalar, or `frame_length` is greater than `fft_length`. + + [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ - with ops.name_scope(name, 'stft', [signal, frame_length, + with ops.name_scope(name, 'stft', [signals, frame_length, frame_step]): - signal = ops.convert_to_tensor(signal, name='signal') - signal.shape.with_rank_at_least(1) + signals = ops.convert_to_tensor(signals, name='signals') + signals.shape.with_rank_at_least(1) frame_length = ops.convert_to_tensor(frame_length, name='frame_length') frame_length.shape.assert_has_rank(0) frame_step = ops.convert_to_tensor(frame_step, name='frame_step') @@ -88,17 +87,17 @@ def stft(signal, frame_length, frame_step, fft_length=None, 'fft_length (%d)' % (frame_length_static, fft_length_static)) - framed_signal = shape_ops.frame( - signal, frame_length, frame_step, pad_end=pad_end) + framed_signals = shape_ops.frame( + signals, frame_length, frame_step, pad_end=pad_end) - # Optionally window the framed signal. + # Optionally window the framed signals. if window_fn is not None: - window = window_fn(frame_length, dtype=framed_signal.dtype) - framed_signal *= window + window = window_fn(frame_length, dtype=framed_signals.dtype) + framed_signals *= window # spectral_ops.rfft produces the (fft_length/2 + 1) unique components of the - # FFT of the real windowed signals in framed_signal. - return spectral_ops.rfft(framed_signal, [fft_length]) + # FFT of the real windowed signals in framed_signals. + return spectral_ops.rfft(framed_signals, [fft_length]) def inverse_stft(stfts, @@ -108,15 +107,14 @@ def inverse_stft(stfts, window_fn=functools.partial(window_ops.hann_window, periodic=True), name=None): - """Computes the inverse Short-time Fourier Transform of a batch of STFTs. - - https://en.wikipedia.org/wiki/Short-time_Fourier_transform + """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`. Implemented with GPU-compatible ops and supports gradients. Args: stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins - representing a batch of `fft_length`-point STFTs. + representing a batch of `fft_length`-point STFTs where `fft_unique_bins` + is `fft_length // 2 + 1` frame_length: An integer scalar `Tensor`. The window length in samples. frame_step: An integer scalar `Tensor`. The number of samples to step. fft_length: An integer scalar `Tensor`. The size of the FFT that produced @@ -133,6 +131,8 @@ def inverse_stft(stfts, Raises: ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, `frame_step` is not scalar, or `fft_length` is not scalar. + + [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ with ops.name_scope(name, 'inverse_stft', [stfts]): stfts = ops.convert_to_tensor(stfts, name='stfts') diff --git a/tensorflow/contrib/signal/python/ops/window_ops.py b/tensorflow/contrib/signal/python/ops/window_ops.py index 2179484ebc..07a847dd2a 100644 --- a/tensorflow/contrib/signal/python/ops/window_ops.py +++ b/tensorflow/contrib/signal/python/ops/window_ops.py @@ -29,9 +29,7 @@ from tensorflow.python.ops import math_ops def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): - """Generate a Hann window. - - https://en.wikipedia.org/wiki/Window_function#Hann_window + """Generate a [Hann window][hann]. Args: window_length: A scalar `Tensor` indicating the window length to generate. @@ -47,6 +45,8 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): Raises: ValueError: If `dtype` is not a floating point type. + + [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window """ return _raised_cosine_window(name, 'hann_window', window_length, periodic, dtype, 0.5, 0.5) @@ -54,9 +54,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): def hamming_window(window_length, periodic=True, dtype=dtypes.float32, name=None): - """Generate a Hamming window. - - https://en.wikipedia.org/wiki/Window_function#Hamming_window + """Generate a [Hamming][hamming] window. Args: window_length: A scalar `Tensor` indicating the window length to generate. @@ -72,6 +70,8 @@ def hamming_window(window_length, periodic=True, dtype=dtypes.float32, Raises: ValueError: If `dtype` is not a floating point type. + + [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window """ return _raised_cosine_window(name, 'hamming_window', window_length, periodic, dtype, 0.54, 0.46) |