author    RJ Ryan <rjryan@google.com>  2017-07-14 13:25:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-07-14 13:30:01 -0700
commit    5051e9743d0482ec6bb0f4681de2c529f99f04e4 (patch)
tree      1590c3dbef9a38f12869b8f569c9d405413d1e02 /tensorflow/contrib/signal
parent    7798a75cce88e07e66d141c9d3ecccf5ae28a0ef (diff)
Add new features to tf.contrib.signal.frames:

- Support arbitrary rank tensors.
- Add an `axis` parameter for framing any axis.
- Add `pad_end` and `pad_value` parameters for controlling the signal padding.
  Padding is disabled by default to avoid a copy of a potentially large Tensor.
- Support shape inference of the resulting framed tensor.
- Expand the tests, including tests for the gradients.
- Since `frames` is a noun and not a verb, rename `frames` to `frame`.

PiperOrigin-RevId: 161998921
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--  tensorflow/contrib/signal/BUILD                                  |   1
-rw-r--r--  tensorflow/contrib/signal/__init__.py                            |   7
-rw-r--r--  tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py  | 288
-rw-r--r--  tensorflow/contrib/signal/python/ops/shape_ops.py                | 189
-rw-r--r--  tensorflow/contrib/signal/python/ops/util_ops.py                 |  57
5 files changed, 488 insertions, 54 deletions
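
For reference, a minimal usage sketch of the API this patch adds (TensorFlow 1.x graph mode; the signal values below are illustrative, not taken from the change):

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x, where tf.contrib.signal is available

# A batch of two 6-sample signals.
signal = tf.constant(np.arange(12).reshape(2, 6), dtype=tf.float32)

# Default behavior (pad_end=False, axis=-1): only frames that fully overlap
# the last axis are produced, so the final partial frame is dropped.
frames = tf.contrib.signal.frame(signal, frame_length=3, frame_step=2)

# With pad_end=True, the final partial frame is kept and padded with pad_value.
padded = tf.contrib.signal.frame(signal, frame_length=3, frame_step=2,
                                 pad_end=True, pad_value=0)

with tf.Session() as sess:
  print(sess.run(frames).shape)  # (2, 2, 3)
  print(sess.run(padded).shape)  # (2, 3, 3)
```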
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 7fa05bf39b..6ed2313dca 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -29,6 +29,7 @@ cuda_py_tests(
":signal_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index b4be2f7b4c..780a090971 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""##Signal ops.
-@@frames
+@@frame
@@hamming_window
@@hann_window
"""
@@ -23,7 +23,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.signal.python.ops.shape_ops import frames
+from tensorflow.contrib.signal.python.ops.shape_ops import frame
+# `frame` used to be named `frames`, which is a noun and not a verb.
+# Keep an alias to `frames` for backwards compatibility.
+from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames
from tensorflow.contrib.signal.python.ops.window_ops import hamming_window
from tensorflow.contrib.signal.python.ops.window_ops import hann_window
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index e07942875f..8633ced599 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -24,18 +24,18 @@ from tensorflow.contrib.signal.python.ops import shape_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FramesTest(test.TestCase):
+class FrameTest(test.TestCase):
def test_mapping_of_indices_without_padding(self):
- with self.test_session():
+ with self.test_session(use_gpu=True):
tensor = constant_op.constant(np.arange(9152), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0)
- result = shape_ops.frames(tensor, 512, 180)
- result = result.eval()
+ result = shape_ops.frame(tensor, 512, 180, pad_end=False).eval()
expected = np.tile(np.arange(512), (49, 1))
expected += np.tile(np.arange(49) * 180, (512, 1)).T
@@ -46,15 +46,14 @@ class FramesTest(test.TestCase):
self.assertAllEqual(expected, result)
def test_mapping_of_indices_with_padding(self):
- with self.test_session():
+ with self.test_session(use_gpu=True):
tensor = constant_op.constant(np.arange(10000), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0)
- result = shape_ops.frames(tensor, 512, 192)
- result = result.eval()
+ result = shape_ops.frame(tensor, 512, 192, pad_end=True).eval()
- expected = np.tile(np.arange(512), (51, 1))
- expected += np.tile(np.arange(51) * 192, (512, 1)).T
+ expected = np.tile(np.arange(512), (53, 1))
+ expected += np.tile(np.arange(53) * 192, (512, 1)).T
expected[expected >= 10000] = 0
@@ -63,6 +62,277 @@ class FramesTest(test.TestCase):
self.assertAllEqual(expected, result)
+ def test_invalid_inputs(self):
+ # Rank 0 input signal.
+ with self.assertRaises(ValueError):
+ shape_ops.frame(1, 1, 1)
+
+ # If the rank is unknown, do not raise an exception.
+ shape_ops.frame(array_ops.placeholder(dtypes.float32), 1, 1)
+
+ # Non-scalar frame_length.
+ with self.assertRaises(ValueError):
+ shape_ops.frame([1], [1], 1)
+
+ # Non-scalar frame_step.
+ with self.assertRaises(ValueError):
+ shape_ops.frame([1], 1, [1])
+
+ # Non-scalar pad_value.
+ with self.assertRaises(ValueError):
+ shape_ops.frame([1], 1, 1, pad_end=True, pad_value=[1])
+
+ def test_length_zero(self):
+ signal = constant_op.constant([], dtype=dtypes.float32)
+ frame_length = 2
+ frame_step = 1
+
+ with self.test_session(use_gpu=True):
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ self.assertEqual((0, 2), result.shape)
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False).eval()
+ self.assertEqual((0, 2), result.shape)
+
+ def test_shape_inference(self):
+ signal = array_ops.placeholder(dtypes.int32, shape=[1, 1])
+ frame_length = 2
+ frame_step = 1
+ # Shape inference is able to detect the rank and inner-most dimension
+ # if frame_length is known at graph definition time.
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99)
+ self.assertEqual([1, 1, 2], result.shape.as_list())
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False)
+ self.assertEqual([1, 0, 2], result.shape.as_list())
+
+ # If frame_length is not known, rank and (known) outer and inner dimensions
+ # are inferred.
+ signal = array_ops.placeholder(dtypes.int32, shape=[1, 2, 3, 4])
+ frame_length = array_ops.placeholder(dtypes.int32, shape=[])
+ frame_step = 1
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99, axis=1)
+ self.assertEqual([1, None, None, 3, 4], result.shape.as_list())
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False, axis=1)
+ self.assertEqual([1, None, None, 3, 4], result.shape.as_list())
+
+ # If frame_length and inner-most dimension is known, rank, inner dimensions,
+ # and known outer dimensions are inferred.
+ signal = array_ops.placeholder(dtypes.int32,
+ shape=[None, 5, None, 20, 5, 3])
+ frame_length = 4
+ frame_step = 3
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99, axis=3)
+ self.assertEqual([None, 5, None, 7, 4, 5, 3], result.shape.as_list())
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False, axis=3)
+ self.assertEqual([None, 5, None, 6, 4, 5, 3], result.shape.as_list())
+
+ # Test that shape inference is consistent with actual returned shapes for
+ # small values of signal_length, frame_length, frame_step, and pad_end in
+ # [True, False].
+ frame_step = 1
+ for signal_length in range(2):
+ signal = [0] * signal_length
+ for frame_length in range(2):
+ for pad_end in [False, True]:
+ op = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=pad_end, pad_value=99)
+ with self.test_session(use_gpu=True):
+ result = op.eval()
+ self.assertEqual(op.shape.as_list(), list(result.shape))
+
+ def test_basic_mono(self):
+ signal = np.arange(6)
+ frame_length = 3
+ frame_step = 2
+
+ with self.test_session(use_gpu=True):
+ for rank in range(5):
+ nd_signal = np.reshape(signal, (1,) * rank + signal.shape)
+
+ # With padding, we pad the last frame with pad_value.
+ result = shape_ops.frame(nd_signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 99]])
+ expected = np.reshape(
+ expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
+ self.assertAllEqual(expected, result)
+
+ # Without padding, we drop the last frame.
+ expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4]])
+ expected = np.reshape(
+ expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
+ result = shape_ops.frame(nd_signal, frame_length, frame_step,
+ pad_end=False).eval()
+ self.assertAllEqual(expected, result)
+
+ def test_basic_stereo(self):
+ signal = np.vstack([np.arange(6),
+ np.arange(6) + 10])
+ frame_length = 3
+ frame_step = 2
+
+ with self.test_session(use_gpu=True):
+ for rank in range(5):
+ nd_signal = np.reshape(signal, (1,) * rank + signal.shape)
+
+ # With padding, we pad the last frame with pad_value.
+ result = shape_ops.frame(nd_signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ expected_inner_frames = np.array([
+ [[0, 1, 2], [2, 3, 4], [4, 5, 99]],
+ [[10, 11, 12], [12, 13, 14], [14, 15, 99]]])
+ expected = np.reshape(
+ expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
+ self.assertAllEqual(expected, result)
+
+ # Without padding, we drop the last frame.
+ expected_inner_frames = np.array([[[0, 1, 2], [2, 3, 4]],
+ [[10, 11, 12], [12, 13, 14]]])
+ expected = np.reshape(
+ expected_inner_frames, (1,) * rank + expected_inner_frames.shape)
+ result = shape_ops.frame(nd_signal, frame_length, frame_step,
+ pad_end=False).eval()
+ self.assertAllEqual(expected, result)
+
+ def test_complex_shape(self):
+ signal = np.vstack([np.arange(6),
+ np.arange(6) + 10,
+ np.arange(6) + 20,
+ np.arange(6) + 30,
+ np.arange(6) + 40,
+ np.arange(6) + 50])
+ signal = np.reshape(signal, (2, 1, 3, 1, 6))
+ frame_length = 3
+ frame_step = 2
+
+ with self.test_session(use_gpu=True):
+ # With padding, we pad the last frame with pad_value.
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ # Resulting shape is (2, 1, 3, 1, 3, 3).
+ expected = [[[[[[0, 1, 2], [2, 3, 4], [4, 5, 99]]],
+ [[[10, 11, 12], [12, 13, 14], [14, 15, 99]]],
+ [[[20, 21, 22], [22, 23, 24], [24, 25, 99]]]]],
+ [[[[[30, 31, 32], [32, 33, 34], [34, 35, 99]]],
+ [[[40, 41, 42], [42, 43, 44], [44, 45, 99]]],
+ [[[50, 51, 52], [52, 53, 54], [54, 55, 99]]]]]]
+ self.assertAllEqual(expected, result)
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False).eval()
+ # Resulting shape is (2, 1, 3, 1, 3, 2).
+ expected = [[[[[[0, 1, 2], [2, 3, 4]]],
+ [[[10, 11, 12], [12, 13, 14]]],
+ [[[20, 21, 22], [22, 23, 24]]]]],
+ [[[[[30, 31, 32], [32, 33, 34]]],
+ [[[40, 41, 42], [42, 43, 44]]],
+ [[[50, 51, 52], [52, 53, 54]]]]]]
+ self.assertAllEqual(expected, result)
+
+ def test_axis(self):
+ signal = np.reshape(np.arange(16), (2, 4, 2))
+ with self.test_session(use_gpu=True):
+ result = shape_ops.frame(signal, frame_length=2, frame_step=2,
+ pad_end=True, axis=1)
+ expected = np.reshape(np.arange(16), (2, 2, 2, 2))
+ self.assertAllEqual(expected, result.eval())
+
+ result = shape_ops.frame(signal, frame_length=2, frame_step=1,
+ pad_end=True, axis=1)
+ expected = [[[[0, 1], [2, 3]],
+ [[2, 3], [4, 5]],
+ [[4, 5], [6, 7]],
+ [[6, 7], [0, 0]]],
+ [[[8, 9], [10, 11]],
+ [[10, 11], [12, 13]],
+ [[12, 13], [14, 15]],
+ [[14, 15], [0, 0]]]]
+ self.assertAllEqual(expected, result.eval())
+
+ result = shape_ops.frame(signal, frame_length=3, frame_step=1,
+ pad_end=True, axis=1)
+ expected = [[[[0, 1], [2, 3], [4, 5]],
+ [[2, 3], [4, 5], [6, 7]],
+ [[4, 5], [6, 7], [0, 0]],
+ [[6, 7], [0, 0], [0, 0]]],
+ [[[8, 9], [10, 11], [12, 13]],
+ [[10, 11], [12, 13], [14, 15]],
+ [[12, 13], [14, 15], [0, 0]],
+ [[14, 15], [0, 0], [0, 0]]]]
+ self.assertAllEqual(expected, result.eval())
+
+ def test_window_larger_than_signal(self):
+ signal = constant_op.constant([[1, 2], [11, 12]], dtype=dtypes.float32)
+ frame_length = 4
+ frame_step = 1
+
+ with self.test_session(use_gpu=True):
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ self.assertAllClose([[[1, 2, 99, 99], [2, 99, 99, 99]],
+ [[11, 12, 99, 99], [12, 99, 99, 99]]], result)
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False).eval()
+ self.assertEqual((2, 0, 4), result.shape)
+
+ frame_step = 2
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=True, pad_value=99).eval()
+ self.assertAllClose([[[1, 2, 99, 99]], [[11, 12, 99, 99]]], result)
+
+ result = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=False).eval()
+ self.assertEqual((2, 0, 4), result.shape)
+
+ def test_preserves_type(self):
+ signal = math_ops.range(10, dtype=dtypes.float64)
+ frame_length = 2
+ frame_step = 3
+
+ with self.test_session(use_gpu=True):
+ result = shape_ops.frame(signal, frame_length, frame_step)
+ self.assertEqual(result.dtype, signal.dtype)
+
+ def test_dynamic_tensor(self):
+ # Show that frame works even when the dimensions of its input are
+ # not known at graph creation time.
+ input_signal = np.vstack([np.arange(4), np.arange(4) + 10,
+ np.arange(4) + 20])
+ frame_length = 2
+ frame_step = 2
+
+ with self.test_session(use_gpu=True) as sess:
+ signal_placeholder = array_ops.placeholder(shape=(None, None),
+ dtype=dtypes.float32)
+ result = sess.run(shape_ops.frame(
+ signal_placeholder, frame_length, frame_step),
+ feed_dict={signal_placeholder: input_signal})
+ self.assertAllEqual([[[0, 1], [2, 3]],
+ [[10, 11], [12, 13]],
+ [[20, 21], [22, 23]]], result)
+
+ def test_gradient_numerical(self):
+ with self.test_session(use_gpu=True):
+ signal_shape = (2, 128)
+ signal = array_ops.ones(signal_shape)
+ frame_length = 33
+ frame_step = 9
+ frames = shape_ops.frame(signal, frame_length, frame_step)
+ error = test.compute_gradient_error(
+ signal, signal_shape, frames, frames.shape.as_list())
+ self.assertLess(error, 2e-5)
if __name__ == "__main__":
test.main()
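
The expected values in the tests above follow directly from the framing semantics. As a cross-check, a pure-NumPy reference for a 1-D signal (illustrative only, not part of the patch):

```python
import numpy as np

def frame_np(signal, frame_length, frame_step, pad_end=False, pad_value=0):
  """NumPy reference for framing a 1-D signal, mirroring the test expectations."""
  signal = np.asarray(signal)
  n = signal.shape[-1]
  if pad_end:
    # Number of frames is ceil(n / frame_step); pad so the last frame is full.
    num_frames = -(-n // frame_step)
    pad = max(0, frame_length + frame_step * (num_frames - 1) - n)
    signal = np.concatenate([signal, np.full(pad, pad_value, signal.dtype)])
  else:
    # Only frames that fully overlap the signal are kept.
    num_frames = max(0, 1 + (n - frame_length) // frame_step)
  indices = (np.arange(frame_length)[None, :] +
             frame_step * np.arange(num_frames)[:, None])
  return signal[indices]

print(frame_np(np.arange(6), 3, 2))
# frames: [0 1 2], [2 3 4]
print(frame_np(np.arange(6), 3, 2, pad_end=True, pad_value=99))
# frames: [0 1 2], [2 3 4], [4 5 99]
```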
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index 4914f19be7..dc7a073242 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -18,70 +18,173 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
+
+from tensorflow.contrib.signal.python.ops import util_ops
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
-def frames(signal, frame_length, frame_step, name=None):
- """Frame a signal into overlapping frames.
-
- May be used in front of spectral functions.
+def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
+ """Infers the shape of the return value of `frame`."""
+ frame_length = tensor_util.constant_value(frame_length)
+ frame_step = tensor_util.constant_value(frame_step)
+ axis = tensor_util.constant_value(axis)
+ if signal.shape.ndims is None:
+ return None
+ if axis is None:
+ return [None] * (signal.shape.ndims + 1)
+
+ signal_shape = signal.shape.as_list()
+ num_frames = None
+ frame_axis = signal_shape[axis]
+ outer_dimensions = signal_shape[:axis]
+ inner_dimensions = signal_shape[axis:][1:]
+ if signal_shape and frame_axis is not None:
+ if frame_step and frame_length is not None:
+ if pad_end:
+ # Double negative is so that we round up.
+ num_frames = -(-frame_axis // frame_step)
+ else:
+ num_frames = (frame_axis - frame_length + frame_step) // frame_step
+ num_frames = max(0, num_frames)
+ return outer_dimensions + [num_frames, frame_length] + inner_dimensions
+
+
+def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
+ name=None):
+ """Expands `signal`'s `axis` dimension into frames of `frame_length`.
+
+ Slides a window of size `frame_length` over `signal`s `axis` dimension
+ with a stride of `frame_step`, replacing the `axis` dimension with
+ `[frames, frame_length]` frames.
+
+ If `pad_end` is True, window positions that are past the end of the `axis`
+ dimension are padded with `pad_value` until the window moves fully past the
+ end of the dimension. Otherwise, only window positions that fully overlap the
+ `axis` dimension are produced.
For example:
```python
pcm = tf.placeholder(tf.float32, [None, 9152])
- frames = tf.contrib.signal.frames(pcm, 512, 180)
+ frames = tf.contrib.signal.frame(pcm, 512, 180)
magspec = tf.abs(tf.spectral.rfft(frames, [512]))
image = tf.expand_dims(magspec, 3)
```
Args:
- signal: A `Tensor` of shape `[batch_size, signal_length]`.
- frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
- frame_step: An `int32` or `int64` `Tensor`. The step between frames.
- name: A name for the operation (optional).
+ signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions
+ may be unknown. Rank must be at least 1.
+ frame_length: The frame length in samples. An integer or scalar `Tensor`.
+ frame_step: The frame hop size in samples. An integer or scalar `Tensor`.
+ pad_end: Whether to pad the end of `signal` with `pad_value`.
+ pad_value: An optional scalar `Tensor` to use where the input signal
+ does not exist when `pad_end` is True.
+ axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to
+ the last axis. Supports negative values for indexing from the end.
+ name: An optional name for the operation.
Returns:
- A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.
+ A `Tensor` of frames with shape `[..., frames, frame_length, ...]`.
Raises:
- ValueError: if signal does not have rank 2.
+ ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar.
"""
- with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
+ with ops.name_scope(name, "frame", [signal, frame_length, frame_step,
+ pad_value]):
signal = ops.convert_to_tensor(signal, name="signal")
frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
-
- signal_rank = signal.shape.ndims
-
- if signal_rank != 2:
- raise ValueError("expected signal to have rank 2 but was " + signal_rank)
-
- signal_length = array_ops.shape(signal)[1]
-
- num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
- num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)
-
- pad_length = (num_frames - 1) * frame_step + frame_length
- pad_signal = array_ops.pad(signal, [[0, 0], [0,
- pad_length - signal_length]])
-
- indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
- indices_frames = array_ops.tile(indices_frame, [num_frames, 1])
-
- indices_step = array_ops.expand_dims(
- math_ops.range(num_frames) * frame_step, 1)
- indices_steps = array_ops.tile(indices_step, [1, frame_length])
-
- indices = indices_frames + indices_steps
-
- # TODO(androbin): remove `transpose` when `gather` gets `axis` support
- pad_signal = array_ops.transpose(pad_signal)
- signal_frames = array_ops.gather(pad_signal, indices)
- signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])
-
- return signal_frames
+ axis = ops.convert_to_tensor(axis, name="axis")
+
+ signal.shape.with_rank_at_least(1)
+ frame_length.shape.assert_has_rank(0)
+ frame_step.shape.assert_has_rank(0)
+ axis.shape.assert_has_rank(0)
+
+ result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end,
+ axis)
+
+ # Axis can be negative. Convert it to positive.
+ signal_rank = array_ops.rank(signal)
+ axis = math_ops.range(signal_rank)[axis]
+
+ signal_shape = array_ops.shape(signal)
+ outer_dimensions, length_samples, inner_dimensions = array_ops.split(
+ signal_shape, [axis, 1, signal_rank - 1 - axis])
+ length_samples = array_ops.reshape(length_samples, [])
+ num_outer_dimensions = array_ops.size(outer_dimensions)
+ num_inner_dimensions = array_ops.size(inner_dimensions)
+
+ # If padding is requested, pad the input signal tensor with pad_value.
+ if pad_end:
+ pad_value = ops.convert_to_tensor(pad_value, signal.dtype)
+ pad_value.shape.assert_has_rank(0)
+
+ # Calculate number of frames, using double negatives to round up.
+ num_frames = -(-length_samples // frame_step)
+
+ # Pad the signal by up to frame_length samples based on how many samples
+ # are remaining starting from last_frame_position.
+ pad_samples = math_ops.maximum(
+ 0, frame_length + frame_step * (num_frames - 1) - length_samples)
+
+ # Pad the inner dimension of signal by pad_samples.
+ paddings = array_ops.concat(
+ [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype),
+ [[0, pad_samples]],
+ array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
+ 0)
+ signal = array_ops.pad(signal, paddings, constant_values=pad_value)
+
+ signal_shape = array_ops.shape(signal)
+ length_samples = signal_shape[axis]
+ else:
+ num_frames = math_ops.maximum(
+ 0, 1 + (length_samples - frame_length) // frame_step)
+
+ subframe_length = util_ops.gcd(frame_length, frame_step)
+ subframes_per_frame = frame_length // subframe_length
+ subframes_per_hop = frame_step // subframe_length
+ num_subframes = length_samples // subframe_length
+
+ slice_shape = array_ops.concat([outer_dimensions,
+ [num_subframes * subframe_length],
+ inner_dimensions], 0)
+ subframe_shape = array_ops.concat([outer_dimensions,
+ [num_subframes, subframe_length],
+ inner_dimensions], 0)
+ subframes = array_ops.reshape(array_ops.strided_slice(
+ signal, array_ops.zeros_like(signal_shape),
+ slice_shape), subframe_shape)
+
+ # frame_selector is a [num_frames, subframes_per_frame] tensor
+ # that indexes into the appropriate frame in subframes. For example:
+ # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]]
+ frame_selector = array_ops.reshape(
+ math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1])
+
+ # subframe_selector is a [num_frames, subframes_per_frame] tensor
+ # that indexes into the appropriate subframe within a frame. For example:
+ # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+ subframe_selector = array_ops.reshape(
+ math_ops.range(subframes_per_frame), [1, subframes_per_frame])
+
+ # Adding the 2 selector tensors together produces a [num_frames,
+ # subframes_per_frame] tensor of indices to use with tf.gather to select
+ # subframes from subframes. We then reshape the inner-most
+ # subframes_per_frame dimension to stitch the subframes together into
+ # frames. For example: [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]].
+ selector = frame_selector + subframe_selector
+
+ frames = array_ops.reshape(
+ array_ops.gather(subframes, selector, axis=axis),
+ array_ops.concat([outer_dimensions, [num_frames, frame_length],
+ inner_dimensions], 0))
+
+ if result_shape:
+ frames.set_shape(result_shape)
+ return frames
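
The implementation above avoids building one gather index per sample: it splits the signal into subframes of length gcd(frame_length, frame_step) and gathers whole subframes, so the selector tensor has only num_frames * subframes_per_frame entries. A standalone NumPy sketch of the same decomposition for a 1-D signal with no end padding (variable names mirror the code above, but this is an illustration, not the library implementation):

```python
import numpy as np
from math import gcd

def frame_via_subframes(signal, frame_length, frame_step):
  """Frame a 1-D signal by gathering gcd-sized subframes (no end padding)."""
  signal = np.asarray(signal)
  subframe_length = gcd(frame_length, frame_step)
  subframes_per_frame = frame_length // subframe_length
  subframes_per_hop = frame_step // subframe_length
  num_subframes = signal.shape[0] // subframe_length
  num_frames = max(0, 1 + (signal.shape[0] - frame_length) // frame_step)

  # View the signal as [num_subframes, subframe_length] blocks.
  subframes = signal[:num_subframes * subframe_length].reshape(
      num_subframes, subframe_length)

  # frame_selector picks the first subframe of each frame; subframe_selector
  # walks the consecutive subframes within a frame. Their sum indexes subframes.
  frame_selector = subframes_per_hop * np.arange(num_frames)[:, None]
  subframe_selector = np.arange(subframes_per_frame)[None, :]
  selector = frame_selector + subframe_selector

  # Gather the subframes and stitch them back into full frames.
  return subframes[selector].reshape(num_frames, frame_length)

print(frame_via_subframes(np.arange(8), frame_length=4, frame_step=2))
# [[0 1 2 3]
#  [2 3 4 5]
#  [4 5 6 7]]
```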
diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py
new file mode 100644
index 0000000000..eee829d799
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/util_ops.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility ops shared across tf.contrib.signal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def gcd(a, b, name=None):
+ """Returns the greatest common divisor via Euclid's algorithm.
+
+ Args:
+ a: The dividend. A scalar integer `Tensor`.
+ b: The divisor. A scalar integer `Tensor`.
+ name: An optional name for the operation.
+
+ Returns:
+ A scalar `Tensor` representing the greatest common divisor between `a` and
+ `b`.
+
+ Raises:
+ ValueError: If `a` or `b` are not scalar integers.
+ """
+ with ops.name_scope(name, 'gcd', [a, b]):
+ a = ops.convert_to_tensor(a)
+ b = ops.convert_to_tensor(b)
+
+ a.shape.assert_has_rank(0)
+ b.shape.assert_has_rank(0)
+
+ if not a.dtype.is_integer:
+ raise ValueError('a must be an integer type. Got: %s' % a.dtype)
+ if not b.dtype.is_integer:
+ raise ValueError('b must be an integer type. Got: %s' % b.dtype)
+
+ cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
+ body = lambda a, b: [b, math_ops.mod(a, b)]
+ a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)
+ return a
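
For intuition, the while_loop above is the standard iterative Euclid recurrence (a, b) -> (b, a mod b), which terminates with the GCD left in `a`. A plain-Python equivalent (illustrative only):

```python
def gcd_euclid(a, b):
  """Plain-Python version of the loop: repeat (a, b) -> (b, a % b) until b == 0."""
  while b > 0:
    a, b = b, a % b
  return a

# gcd(frame_length, frame_step) values that appear elsewhere in this change:
assert gcd_euclid(512, 180) == 4  # docstring example: frame_length=512, frame_step=180
assert gcd_euclid(33, 9) == 3     # gradient test: frame_length=33, frame_step=9
```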