aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-07-20 13:36:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-20 13:41:22 -0700
commit97cf16bb937e0cea7c024b2e689b198795682101 (patch)
tree1c8edc24c7d8f855843c4d67c669369f7500203c /tensorflow/contrib/signal
parent9f4ec7e8492ae4527315a7ef2df4beb5cc8b9ab8 (diff)
Documentation fixes and polishing for tf.contrib.signal.
PiperOrigin-RevId: 162658696
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/__init__.py6
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py5
-rw-r--r--tensorflow/contrib/signal/python/ops/spectral_ops.py48
-rw-r--r--tensorflow/contrib/signal/python/ops/window_ops.py12
4 files changed, 38 insertions, 33 deletions
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index c0136346b7..6cc51d6fb0 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""##Signal ops.
+"""Signal processing operations.
@@frame
@@hamming_window
@@ -20,6 +20,10 @@
@@inverse_stft
@@overlap_and_add
@@stft
+
+[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
+[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
+[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index dc7a073242..1ddc2941ec 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -57,7 +57,7 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
name=None):
"""Expands `signal`'s `axis` dimension into frames of `frame_length`.
- Slides a window of size `frame_length` over `signal`s `axis` dimension
+ Slides a window of size `frame_length` over `signal`'s `axis` dimension
with a stride of `frame_step`, replacing the `axis` dimension with
`[frames, frame_length]` frames.
@@ -91,7 +91,8 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
A `Tensor` of frames with shape `[..., frames, frame_length, ...]`.
Raises:
- ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar.
+ ValueError: If `frame_length`, `frame_step`, `pad_value`, or `axis` are not
+ scalar.
"""
with ops.name_scope(name, "frame", [signal, frame_length, frame_step,
pad_value]):
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
index 0d1ef0d3d1..75bc0bd21d 100644
--- a/tensorflow/contrib/signal/python/ops/spectral_ops.py
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -32,17 +32,15 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
-def stft(signal, frame_length, frame_step, fft_length=None,
+def stft(signals, frame_length, frame_step, fft_length=None,
window_fn=functools.partial(window_ops.hann_window, periodic=True),
pad_end=False, name=None):
- """Computes the Short-time Fourier Transform of a batch of real signals.
-
- https://en.wikipedia.org/wiki/Short-time_Fourier_transform
+ """Computes the [Short-time Fourier Transform][stft] of `signals`.
Implemented with GPU-compatible ops and supports gradients.
Args:
- signal: A `[..., samples]` `float32` `Tensor` of real-valued signals.
+ signals: A `[..., samples]` `float32` `Tensor` of real-valued signals.
frame_length: An integer scalar `Tensor`. The window length in samples.
frame_step: An integer scalar `Tensor`. The number of samples to step.
fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
@@ -50,25 +48,26 @@ def stft(signal, frame_length, frame_step, fft_length=None,
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
- pad_end: Whether to pad the end of signal with zeros when the provided
- frame length and step produces a frame that lies partially past the end
- of `signal`.
+ pad_end: Whether to pad the end of `signals` with zeros when the provided
+ frame length and step produces a frame that lies partially past its end.
name: An optional name for the operation.
Returns:
A `[..., frames, fft_unique_bins]` `Tensor` of `complex64` STFT values where
- `fft_unique_bins` is `fft_length / 2 + 1` (the unique components of the
+ `fft_unique_bins` is `fft_length // 2 + 1` (the unique components of the
FFT).
Raises:
- ValueError: If `signal` is not at least rank 1, `frame_length` is
+ ValueError: If `signals` is not at least rank 1, `frame_length` is
not scalar, `frame_step` is not scalar, or `frame_length`
is greater than `fft_length`.
+
+ [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
- with ops.name_scope(name, 'stft', [signal, frame_length,
+ with ops.name_scope(name, 'stft', [signals, frame_length,
frame_step]):
- signal = ops.convert_to_tensor(signal, name='signal')
- signal.shape.with_rank_at_least(1)
+ signals = ops.convert_to_tensor(signals, name='signals')
+ signals.shape.with_rank_at_least(1)
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
frame_length.shape.assert_has_rank(0)
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
@@ -88,17 +87,17 @@ def stft(signal, frame_length, frame_step, fft_length=None,
'fft_length (%d)' % (frame_length_static,
fft_length_static))
- framed_signal = shape_ops.frame(
- signal, frame_length, frame_step, pad_end=pad_end)
+ framed_signals = shape_ops.frame(
+ signals, frame_length, frame_step, pad_end=pad_end)
- # Optionally window the framed signal.
+ # Optionally window the framed signals.
if window_fn is not None:
- window = window_fn(frame_length, dtype=framed_signal.dtype)
- framed_signal *= window
+ window = window_fn(frame_length, dtype=framed_signals.dtype)
+ framed_signals *= window
# spectral_ops.rfft produces the (fft_length/2 + 1) unique components of the
- # FFT of the real windowed signals in framed_signal.
- return spectral_ops.rfft(framed_signal, [fft_length])
+ # FFT of the real windowed signals in framed_signals.
+ return spectral_ops.rfft(framed_signals, [fft_length])
def inverse_stft(stfts,
@@ -108,15 +107,14 @@ def inverse_stft(stfts,
window_fn=functools.partial(window_ops.hann_window,
periodic=True),
name=None):
- """Computes the inverse Short-time Fourier Transform of a batch of STFTs.
-
- https://en.wikipedia.org/wiki/Short-time_Fourier_transform
+ """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`.
Implemented with GPU-compatible ops and supports gradients.
Args:
stfts: A `complex64` `[..., frames, fft_unique_bins]` `Tensor` of STFT bins
- representing a batch of `fft_length`-point STFTs.
+ representing a batch of `fft_length`-point STFTs where `fft_unique_bins`
+ is `fft_length // 2 + 1`
frame_length: An integer scalar `Tensor`. The window length in samples.
frame_step: An integer scalar `Tensor`. The number of samples to step.
fft_length: An integer scalar `Tensor`. The size of the FFT that produced
@@ -133,6 +131,8 @@ def inverse_stft(stfts,
Raises:
ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar,
`frame_step` is not scalar, or `fft_length` is not scalar.
+
+ [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
"""
with ops.name_scope(name, 'inverse_stft', [stfts]):
stfts = ops.convert_to_tensor(stfts, name='stfts')
diff --git a/tensorflow/contrib/signal/python/ops/window_ops.py b/tensorflow/contrib/signal/python/ops/window_ops.py
index 2179484ebc..07a847dd2a 100644
--- a/tensorflow/contrib/signal/python/ops/window_ops.py
+++ b/tensorflow/contrib/signal/python/ops/window_ops.py
@@ -29,9 +29,7 @@ from tensorflow.python.ops import math_ops
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
- """Generate a Hann window.
-
- https://en.wikipedia.org/wiki/Window_function#Hann_window
+ """Generate a [Hann window][hann].
Args:
window_length: A scalar `Tensor` indicating the window length to generate.
@@ -47,6 +45,8 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
Raises:
ValueError: If `dtype` is not a floating point type.
+
+ [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
"""
return _raised_cosine_window(name, 'hann_window', window_length, periodic,
dtype, 0.5, 0.5)
@@ -54,9 +54,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
name=None):
- """Generate a Hamming window.
-
- https://en.wikipedia.org/wiki/Window_function#Hamming_window
+ """Generate a [Hamming][hamming] window.
Args:
window_length: A scalar `Tensor` indicating the window length to generate.
@@ -72,6 +70,8 @@ def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
Raises:
ValueError: If `dtype` is not a floating point type.
+
+ [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
"""
return _raised_cosine_window(name, 'hamming_window', window_length, periodic,
dtype, 0.54, 0.46)