diff options
author | RJ Ryan <rjryan@google.com> | 2017-07-14 13:25:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-14 13:30:01 -0700 |
commit | 5051e9743d0482ec6bb0f4681de2c529f99f04e4 (patch) | |
tree | 1590c3dbef9a38f12869b8f569c9d405413d1e02 /tensorflow/contrib/signal | |
parent | 7798a75cce88e07e66d141c9d3ecccf5ae28a0ef (diff) |
Add new features to tf.contrib.signal.frames:
- Support arbitrary rank tensors.
- Add an `axis` parameter for framing any axis.
- Add `pad_end` and `pad_value` parameters for controlling the signal padding. Padding is disabled by default to avoid copying a potentially large Tensor.
- Support shape inference of the resulting framed tensor.
- Expand the tests, including tests for the gradients.
- Since `frames` is a noun and not a verb, rename `frames` to `frame`.
PiperOrigin-RevId: 161998921
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/signal/__init__.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py | 288 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/shape_ops.py | 189 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/util_ops.py | 57 |
5 files changed, 488 insertions, 54 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 7fa05bf39b..6ed2313dca 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -29,6 +29,7 @@ cuda_py_tests( ":signal_py", "//third_party/py/numpy", "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index b4be2f7b4c..780a090971 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """##Signal ops. -@@frames +@@frame @@hamming_window @@hann_window """ @@ -23,7 +23,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.signal.python.ops.shape_ops import frames +from tensorflow.contrib.signal.python.ops.shape_ops import frame +# `frame` used to be named `frames`, which is a noun and not a verb. +# Keep an alias to `frames` for backwards compatibility. 
+from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames from tensorflow.contrib.signal.python.ops.window_ops import hamming_window from tensorflow.contrib.signal.python.ops.window_ops import hann_window diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index e07942875f..8633ced599 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -24,18 +24,18 @@ from tensorflow.contrib.signal.python.ops import shape_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FramesTest(test.TestCase): +class FrameTest(test.TestCase): def test_mapping_of_indices_without_padding(self): - with self.test_session(): + with self.test_session(use_gpu=True): tensor = constant_op.constant(np.arange(9152), dtypes.int32) tensor = array_ops.expand_dims(tensor, 0) - result = shape_ops.frames(tensor, 512, 180) - result = result.eval() + result = shape_ops.frame(tensor, 512, 180, pad_end=False).eval() expected = np.tile(np.arange(512), (49, 1)) expected += np.tile(np.arange(49) * 180, (512, 1)).T @@ -46,15 +46,14 @@ class FramesTest(test.TestCase): self.assertAllEqual(expected, result) def test_mapping_of_indices_with_padding(self): - with self.test_session(): + with self.test_session(use_gpu=True): tensor = constant_op.constant(np.arange(10000), dtypes.int32) tensor = array_ops.expand_dims(tensor, 0) - result = shape_ops.frames(tensor, 512, 192) - result = result.eval() + result = shape_ops.frame(tensor, 512, 192, pad_end=True).eval() - expected = np.tile(np.arange(512), (51, 1)) - expected += np.tile(np.arange(51) * 192, (512, 1)).T + expected = np.tile(np.arange(512), (53, 1)) + expected += 
np.tile(np.arange(53) * 192, (512, 1)).T expected[expected >= 10000] = 0 @@ -63,6 +62,277 @@ class FramesTest(test.TestCase): self.assertAllEqual(expected, result) + def test_invalid_inputs(self): + # Rank 0 input signal. + with self.assertRaises(ValueError): + shape_ops.frame(1, 1, 1) + + # If the rank is unknown, do not raise an exception. + shape_ops.frame(array_ops.placeholder(dtypes.float32), 1, 1) + + # Non-scalar frame_length. + with self.assertRaises(ValueError): + shape_ops.frame([1], [1], 1) + + # Non-scalar frame_step. + with self.assertRaises(ValueError): + shape_ops.frame([1], 1, [1]) + + # Non-scalar pad_value. + with self.assertRaises(ValueError): + shape_ops.frame([1], 1, 1, pad_end=True, pad_value=[1]) + + def test_length_zero(self): + signal = constant_op.constant([], dtype=dtypes.float32) + frame_length = 2 + frame_step = 1 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertEqual((0, 2), result.shape) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((0, 2), result.shape) + + def test_shape_inference(self): + signal = array_ops.placeholder(dtypes.int32, shape=[1, 1]) + frame_length = 2 + frame_step = 1 + # Shape inference is able to detect the rank and inner-most dimension + # if frame_length is known at graph definition time. + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99) + self.assertEqual([1, 1, 2], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False) + self.assertEqual([1, 0, 2], result.shape.as_list()) + + # If frame_length is not known, rank and (known) outer and inner dimensions + # are inferred. 
+ signal = array_ops.placeholder(dtypes.int32, shape=[1, 2, 3, 4]) + frame_length = array_ops.placeholder(dtypes.int32, shape=[]) + frame_step = 1 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99, axis=1) + self.assertEqual([1, None, None, 3, 4], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False, axis=1) + self.assertEqual([1, None, None, 3, 4], result.shape.as_list()) + + # If frame_length and inner-most dimension is known, rank, inner dimensions, + # and known outer dimensions are inferred. + signal = array_ops.placeholder(dtypes.int32, + shape=[None, 5, None, 20, 5, 3]) + frame_length = 4 + frame_step = 3 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99, axis=3) + self.assertEqual([None, 5, None, 7, 4, 5, 3], result.shape.as_list()) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False, axis=3) + self.assertEqual([None, 5, None, 6, 4, 5, 3], result.shape.as_list()) + + # Test that shape inference is consistent with actual returned shapes for + # small values of signal_length, frame_length, frame_step, and pad_end in + # [True, False]. + frame_step = 1 + for signal_length in range(2): + signal = [0] * signal_length + for frame_length in range(2): + for pad_end in [False, True]: + op = shape_ops.frame(signal, frame_length, frame_step, + pad_end=pad_end, pad_value=99) + with self.test_session(use_gpu=True): + result = op.eval() + self.assertEqual(op.shape.as_list(), list(result.shape)) + + def test_basic_mono(self): + signal = np.arange(6) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + for rank in range(5): + nd_signal = np.reshape(signal, (1,) * rank + signal.shape) + + # With padding, we pad the last frame with pad_value. 
+ result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 99]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + self.assertAllEqual(expected, result) + + # Without padding, we drop the last frame. + expected_inner_frames = np.array([[0, 1, 2], [2, 3, 4]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=False).eval() + self.assertAllEqual(expected, result) + + def test_basic_stereo(self): + signal = np.vstack([np.arange(6), + np.arange(6) + 10]) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + for rank in range(5): + nd_signal = np.reshape(signal, (1,) * rank + signal.shape) + + # With padding, we pad the last frame with pad_value. + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + expected_inner_frames = np.array([ + [[0, 1, 2], [2, 3, 4], [4, 5, 99]], + [[10, 11, 12], [12, 13, 14], [14, 15, 99]]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + self.assertAllEqual(expected, result) + + # Without padding, we drop the last frame. 
+ expected_inner_frames = np.array([[[0, 1, 2], [2, 3, 4]], + [[10, 11, 12], [12, 13, 14]]]) + expected = np.reshape( + expected_inner_frames, (1,) * rank + expected_inner_frames.shape) + result = shape_ops.frame(nd_signal, frame_length, frame_step, + pad_end=False).eval() + self.assertAllEqual(expected, result) + + def test_complex_shape(self): + signal = np.vstack([np.arange(6), + np.arange(6) + 10, + np.arange(6) + 20, + np.arange(6) + 30, + np.arange(6) + 40, + np.arange(6) + 50]) + signal = np.reshape(signal, (2, 1, 3, 1, 6)) + frame_length = 3 + frame_step = 2 + + with self.test_session(use_gpu=True): + # With padding, we pad the last frame with pad_value. + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + # Resulting shape is (2, 1, 3, 1, 3, 3). + expected = [[[[[[0, 1, 2], [2, 3, 4], [4, 5, 99]]], + [[[10, 11, 12], [12, 13, 14], [14, 15, 99]]], + [[[20, 21, 22], [22, 23, 24], [24, 25, 99]]]]], + [[[[[30, 31, 32], [32, 33, 34], [34, 35, 99]]], + [[[40, 41, 42], [42, 43, 44], [44, 45, 99]]], + [[[50, 51, 52], [52, 53, 54], [54, 55, 99]]]]]] + self.assertAllEqual(expected, result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + # Resulting shape is (2, 1, 3, 1, 3, 2). 
+ expected = [[[[[[0, 1, 2], [2, 3, 4]]], + [[[10, 11, 12], [12, 13, 14]]], + [[[20, 21, 22], [22, 23, 24]]]]], + [[[[[30, 31, 32], [32, 33, 34]]], + [[[40, 41, 42], [42, 43, 44]]], + [[[50, 51, 52], [52, 53, 54]]]]]] + self.assertAllEqual(expected, result) + + def test_axis(self): + signal = np.reshape(np.arange(16), (2, 4, 2)) + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length=2, frame_step=2, + pad_end=True, axis=1) + expected = np.reshape(np.arange(16), (2, 2, 2, 2)) + self.assertAllEqual(expected, result.eval()) + + result = shape_ops.frame(signal, frame_length=2, frame_step=1, + pad_end=True, axis=1) + expected = [[[[0, 1], [2, 3]], + [[2, 3], [4, 5]], + [[4, 5], [6, 7]], + [[6, 7], [0, 0]]], + [[[8, 9], [10, 11]], + [[10, 11], [12, 13]], + [[12, 13], [14, 15]], + [[14, 15], [0, 0]]]] + self.assertAllEqual(expected, result.eval()) + + result = shape_ops.frame(signal, frame_length=3, frame_step=1, + pad_end=True, axis=1) + expected = [[[[0, 1], [2, 3], [4, 5]], + [[2, 3], [4, 5], [6, 7]], + [[4, 5], [6, 7], [0, 0]], + [[6, 7], [0, 0], [0, 0]]], + [[[8, 9], [10, 11], [12, 13]], + [[10, 11], [12, 13], [14, 15]], + [[12, 13], [14, 15], [0, 0]], + [[14, 15], [0, 0], [0, 0]]]] + self.assertAllEqual(expected, result.eval()) + + def test_window_larger_than_signal(self): + signal = constant_op.constant([[1, 2], [11, 12]], dtype=dtypes.float32) + frame_length = 4 + frame_step = 1 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertAllClose([[[1, 2, 99, 99], [2, 99, 99, 99]], + [[11, 12, 99, 99], [12, 99, 99, 99]]], result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((2, 0, 4), result.shape) + + frame_step = 2 + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=True, pad_value=99).eval() + self.assertAllClose([[[1, 2, 99, 99]], [[11, 12, 99, 99]]], 
result) + + result = shape_ops.frame(signal, frame_length, frame_step, + pad_end=False).eval() + self.assertEqual((2, 0, 4), result.shape) + + def test_preserves_type(self): + signal = math_ops.range(10, dtype=dtypes.float64) + frame_length = 2 + frame_step = 3 + + with self.test_session(use_gpu=True): + result = shape_ops.frame(signal, frame_length, frame_step) + self.assertEqual(result.dtype, signal.dtype) + + def test_dynamic_tensor(self): + # Show that frame works even when the dimensions of its input are + # not known at graph creation time. + input_signal = np.vstack([np.arange(4), np.arange(4) + 10, + np.arange(4) + 20]) + frame_length = 2 + frame_step = 2 + + with self.test_session(use_gpu=True) as sess: + signal_placeholder = array_ops.placeholder(shape=(None, None), + dtype=dtypes.float32) + result = sess.run(shape_ops.frame( + signal_placeholder, frame_length, frame_step), + feed_dict={signal_placeholder: input_signal}) + self.assertAllEqual([[[0, 1], [2, 3]], + [[10, 11], [12, 13]], + [[20, 21], [22, 23]]], result) + + def test_gradient_numerical(self): + with self.test_session(use_gpu=True): + signal_shape = (2, 128) + signal = array_ops.ones(signal_shape) + frame_length = 33 + frame_step = 9 + frames = shape_ops.frame(signal, frame_length, frame_step) + error = test.compute_gradient_error( + signal, signal_shape, frames, frames.shape.as_list()) + self.assertLess(error, 2e-5) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py index 4914f19be7..dc7a073242 100644 --- a/tensorflow/contrib/signal/python/ops/shape_ops.py +++ b/tensorflow/contrib/signal/python/ops/shape_ops.py @@ -18,70 +18,173 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes + +from tensorflow.contrib.signal.python.ops import util_ops from tensorflow.python.framework import ops 
+from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -def frames(signal, frame_length, frame_step, name=None): - """Frame a signal into overlapping frames. - - May be used in front of spectral functions. +def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis): + """Infers the shape of the return value of `frame`.""" + frame_length = tensor_util.constant_value(frame_length) + frame_step = tensor_util.constant_value(frame_step) + axis = tensor_util.constant_value(axis) + if signal.shape.ndims is None: + return None + if axis is None: + return [None] * (signal.shape.ndims + 1) + + signal_shape = signal.shape.as_list() + num_frames = None + frame_axis = signal_shape[axis] + outer_dimensions = signal_shape[:axis] + inner_dimensions = signal_shape[axis:][1:] + if signal_shape and frame_axis is not None: + if frame_step and frame_length is not None: + if pad_end: + # Double negative is so that we round up. + num_frames = -(-frame_axis // frame_step) + else: + num_frames = (frame_axis - frame_length + frame_step) // frame_step + num_frames = max(0, num_frames) + return outer_dimensions + [num_frames, frame_length] + inner_dimensions + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1, + name=None): + """Expands `signal`'s `axis` dimension into frames of `frame_length`. + + Slides a window of size `frame_length` over `signal`s `axis` dimension + with a stride of `frame_step`, replacing the `axis` dimension with + `[frames, frame_length]` frames. + + If `pad_end` is True, window positions that are past the end of the `axis` + dimension are padded with `pad_value` until the window moves fully past the + end of the dimension. Otherwise, only window positions that fully overlap the + `axis` dimension are produced. 
For example: ```python pcm = tf.placeholder(tf.float32, [None, 9152]) - frames = tf.contrib.signal.frames(pcm, 512, 180) + frames = tf.contrib.signal.frame(pcm, 512, 180) magspec = tf.abs(tf.spectral.rfft(frames, [512])) image = tf.expand_dims(magspec, 3) ``` Args: - signal: A `Tensor` of shape `[batch_size, signal_length]`. - frame_length: An `int32` or `int64` `Tensor`. The length of each frame. - frame_step: An `int32` or `int64` `Tensor`. The step between frames. - name: A name for the operation (optional). + signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions + may be unknown. Rank must be at least 1. + frame_length: The frame length in samples. An integer or scalar `Tensor`. + frame_step: The frame hop size in samples. An integer or scalar `Tensor`. + pad_end: Whether to pad the end of `signal` with `pad_value`. + pad_value: An optional scalar `Tensor` to use where the input signal + does not exist when `pad_end` is True. + axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to + the last axis. Supports negative values for indexing from the end. + name: An optional name for the operation. Returns: - A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`. + A `Tensor` of frames with shape `[..., frames, frame_length, ...]`. Raises: - ValueError: if signal does not have rank 2. + ValueError: If `frame_length`, `frame_step`, or `pad_value` are not scalar. 
""" - with ops.name_scope(name, "frames", [signal, frame_length, frame_step]): + with ops.name_scope(name, "frame", [signal, frame_length, frame_step, + pad_value]): signal = ops.convert_to_tensor(signal, name="signal") frame_length = ops.convert_to_tensor(frame_length, name="frame_length") frame_step = ops.convert_to_tensor(frame_step, name="frame_step") - - signal_rank = signal.shape.ndims - - if signal_rank != 2: - raise ValueError("expected signal to have rank 2 but was " + signal_rank) - - signal_length = array_ops.shape(signal)[1] - - num_frames = math_ops.ceil((signal_length - frame_length) / frame_step) - num_frames = 1 + math_ops.cast(num_frames, dtypes.int32) - - pad_length = (num_frames - 1) * frame_step + frame_length - pad_signal = array_ops.pad(signal, [[0, 0], [0, - pad_length - signal_length]]) - - indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0) - indices_frames = array_ops.tile(indices_frame, [num_frames, 1]) - - indices_step = array_ops.expand_dims( - math_ops.range(num_frames) * frame_step, 1) - indices_steps = array_ops.tile(indices_step, [1, frame_length]) - - indices = indices_frames + indices_steps - - # TODO(androbin): remove `transpose` when `gather` gets `axis` support - pad_signal = array_ops.transpose(pad_signal) - signal_frames = array_ops.gather(pad_signal, indices) - signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1]) - - return signal_frames + axis = ops.convert_to_tensor(axis, name="axis") + + signal.shape.with_rank_at_least(1) + frame_length.shape.assert_has_rank(0) + frame_step.shape.assert_has_rank(0) + axis.shape.assert_has_rank(0) + + result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end, + axis) + + # Axis can be negative. Convert it to positive. 
+ signal_rank = array_ops.rank(signal) + axis = math_ops.range(signal_rank)[axis] + + signal_shape = array_ops.shape(signal) + outer_dimensions, length_samples, inner_dimensions = array_ops.split( + signal_shape, [axis, 1, signal_rank - 1 - axis]) + length_samples = array_ops.reshape(length_samples, []) + num_outer_dimensions = array_ops.size(outer_dimensions) + num_inner_dimensions = array_ops.size(inner_dimensions) + + # If padding is requested, pad the input signal tensor with pad_value. + if pad_end: + pad_value = ops.convert_to_tensor(pad_value, signal.dtype) + pad_value.shape.assert_has_rank(0) + + # Calculate number of frames, using double negatives to round up. + num_frames = -(-length_samples // frame_step) + + # Pad the signal by up to frame_length samples based on how many samples + # are remaining starting from last_frame_position. + pad_samples = math_ops.maximum( + 0, frame_length + frame_step * (num_frames - 1) - length_samples) + + # Pad the inner dimension of signal by pad_samples. 
+ paddings = array_ops.concat( + [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype), + [[0, pad_samples]], + array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)], + 0) + signal = array_ops.pad(signal, paddings, constant_values=pad_value) + + signal_shape = array_ops.shape(signal) + length_samples = signal_shape[axis] + else: + num_frames = math_ops.maximum( + 0, 1 + (length_samples - frame_length) // frame_step) + + subframe_length = util_ops.gcd(frame_length, frame_step) + subframes_per_frame = frame_length // subframe_length + subframes_per_hop = frame_step // subframe_length + num_subframes = length_samples // subframe_length + + slice_shape = array_ops.concat([outer_dimensions, + [num_subframes * subframe_length], + inner_dimensions], 0) + subframe_shape = array_ops.concat([outer_dimensions, + [num_subframes, subframe_length], + inner_dimensions], 0) + subframes = array_ops.reshape(array_ops.strided_slice( + signal, array_ops.zeros_like(signal_shape), + slice_shape), subframe_shape) + + # frame_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate frame in subframes. For example: + # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]] + frame_selector = array_ops.reshape( + math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1]) + + # subframe_selector is a [num_frames, subframes_per_frame] tensor + # that indexes into the appropriate subframe within a frame. For example: + # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] + subframe_selector = array_ops.reshape( + math_ops.range(subframes_per_frame), [1, subframes_per_frame]) + + # Adding the 2 selector tensors together produces a [num_frames, + # subframes_per_frame] tensor of indices to use with tf.gather to select + # subframes from subframes. We then reshape the inner-most + # subframes_per_frame dimension to stitch the subframes together into + # frames. For example: [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]. 
+ selector = frame_selector + subframe_selector + + frames = array_ops.reshape( + array_ops.gather(subframes, selector, axis=axis), + array_ops.concat([outer_dimensions, [num_frames, frame_length], + inner_dimensions], 0)) + + if result_shape: + frames.set_shape(result_shape) + return frames diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py new file mode 100644 index 0000000000..eee829d799 --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/util_ops.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility ops shared across tf.contrib.signal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + + +def gcd(a, b, name=None): + """Returns the greatest common divisor via Euclid's algorithm. + + Args: + a: The dividend. A scalar integer `Tensor`. + b: The divisor. A scalar integer `Tensor`. + name: An optional name for the operation. + + Returns: + A scalar `Tensor` representing the greatest common divisor between `a` and + `b`. + + Raises: + ValueError: If `a` or `b` are not scalar integers. 
+ """ + with ops.name_scope(name, 'gcd', [a, b]): + a = ops.convert_to_tensor(a) + b = ops.convert_to_tensor(b) + + a.shape.assert_has_rank(0) + b.shape.assert_has_rank(0) + + if not a.dtype.is_integer: + raise ValueError('a must be an integer type. Got: %s' % a.dtype) + if not b.dtype.is_integer: + raise ValueError('b must be an integer type. Got: %s' % b.dtype) + + cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b)) + body = lambda a, b: [b, math_ops.mod(a, b)] + a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False) + return a |