diff options
author | RJ Ryan <rjryan@google.com> | 2017-07-14 15:32:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-14 15:37:18 -0700 |
commit | ff5b2bf093575b852848b2d9a1fc31684a5800da (patch) | |
tree | 814de599089b07d715424fd31154c9202e8667a4 /tensorflow/contrib/signal | |
parent | cc741b48e8bba90670fe9db3811d50f283ea8b1b (diff) |
Add overlap_and_add to tf.contrib.signal.
This is the inverse operation of tf.contrib.signal.frame. This is implemented using GPU-capable ops and supports gradients.
PiperOrigin-RevId: 162017464
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/contrib/signal/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py | 192 | ||||
-rw-r--r-- | tensorflow/contrib/signal/python/ops/reconstruction_ops.py | 144 |
4 files changed, 355 insertions, 1 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 6ed2313dca..706de58a0a 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -22,8 +22,24 @@ py_library( ) cuda_py_tests( + name = "reconstruction_ops_test", + srcs = ["python/kernel_tests/reconstruction_ops_test.py"], + additional_deps = [ + ":signal_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( name = "shape_ops_test", - size = "small", srcs = ["python/kernel_tests/shape_ops_test.py"], additional_deps = [ ":signal_py", diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index 780a090971..d0f6c1d0c6 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -17,12 +17,14 @@ @@frame @@hamming_window @@hann_window +@@overlap_and_add """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add from tensorflow.contrib.signal.python.ops.shape_ops import frame # `frame` used to be named `frames`, which is a noun and not a verb. # Keep an alias to `frames` for backwards compatibility. diff --git a/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py new file mode 100644 index 0000000000..5c9b2ac518 --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py @@ -0,0 +1,192 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reconstruction_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.signal.python.ops import reconstruction_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ReconstructionOpsTest(test.TestCase): + + def __init__(self, *args, **kwargs): + super(ReconstructionOpsTest, self).__init__(*args, **kwargs) + self.batch_size = 3 + self.frames = 3 + self.samples = 5 + + self.bases = np.array(range(2, 5)) + exponents = np.array(range(self.frames * self.samples)) + powers = np.power(self.bases[:, np.newaxis], exponents[np.newaxis, :]) + + self.powers = np.reshape(powers, [self.batch_size, self.frames, + self.samples]) + self.frame_hop = 2 + + # Hand computed example using powers of unique numbers: this is easily + # verified. + self.expected_string = ["1", "10", "100100", "1001000", "10010010000", + "100100000000", "1001000000000", "10000000000000", + "100000000000000"] + + def test_all_ones(self): + signal = constant_op.constant(np.ones((3, 5)), dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, 2) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + expected_output = np.array([1, 1, 2, 2, 3, 2, 2, 1, 1]) + + self.assertAllClose(output, expected_output) + + def test_simple(self): + def make_input(frame_length, num_frames=3): + """Generate a tensor of num_frames frames of frame_length.""" + return np.reshape(np.arange(1, num_frames * frame_length + 1), + (-1, frame_length)) + + # List of (signal, expected_result, frame_hop). + configurations = [ + # All hop lengths on a frame length of 2. + (make_input(2), [1, 5, 9, 6], 1), + (make_input(2), [1, 2, 3, 4, 5, 6], 2), + + # All hop lengths on a frame length of 3. + (make_input(3), [1, 6, 15, 14, 9], 1), + (make_input(3), [1, 2, 7, 5, 13, 8, 9], 2), + (make_input(3), [1, 2, 3, 4, 5, 6, 7, 8, 9], 3), + + # All hop lengths on a frame length of 4. + (make_input(4), [1, 7, 18, 21, 19, 12], 1), + (make_input(4), [1, 2, 8, 10, 16, 18, 11, 12], 2), + (make_input(4), [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3), + (make_input(4), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4), + ] + + with self.test_session(use_gpu=True): + for signal, expected, frame_hop in configurations: + reconstruction = reconstruction_ops.overlap_and_add( + np.array(signal), frame_hop).eval() + expected_output = np.array(expected) + self.assertAllClose(reconstruction, expected_output) + + def test_powers(self): + signal = constant_op.constant(np.squeeze(self.powers[0, :, :]), + dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + string_output = [np.base_repr(x, self.bases[0]) for x in output] + + self.assertEqual(string_output, self.expected_string) + + def test_batch(self): + signal = constant_op.constant(self.powers, dtype=dtypes.int64) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + accumulator = True + for i in range(self.batch_size): + string_output = [np.base_repr(x, self.bases[i]) for x in output[i, :]] + accumulator = accumulator and (string_output == self.expected_string) + + self.assertTrue(accumulator) + + def test_one_element_batch(self): + input_matrix = np.squeeze(self.powers[0, :, :]) + input_matrix = input_matrix[np.newaxis, :, :].astype(float) + signal = constant_op.constant(input_matrix, dtype=dtypes.float32) + reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(reconstruction) + + string_output = [np.base_repr(int(x), self.bases[0]) for x in + np.squeeze(output)] + + self.assertEqual(output.shape, (1, 9)) + self.assertEqual(string_output, self.expected_string) + + def test_gradient(self): + configurations = [ + ((1, 128), 1), + ((5, 35), 17), + ((10, 128), 128), + ((2, 10, 128), 127), + ((2, 2, 10, 128), 126), + ((2, 2, 2, 10, 128), 125), + ] + + for shape, frame_hop in configurations: + with self.test_session(use_gpu=True) as sess: + signal = array_ops.zeros(shape) + reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop) + loss = math_ops.reduce_sum(reconstruction) + # Increasing any sample in the input frames by one will increase the sum + # of all the samples in the reconstruction by 1, so the gradient should + # be all ones, no matter the shape or hop. + gradient = sess.run(gradients_impl.gradients([loss], [signal])[0]) + self.assertTrue((gradient == 1.0).all()) + + def test_gradient_batch(self): + with self.test_session(use_gpu=True) as sess: + signal = array_ops.zeros((2, 10, 10)) + frame_hop = 10 + reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop) + + # Multiply the first batch-item's reconstruction by zeros. This will block + # gradient from flowing into the first batch item from the loss. Multiply + # the second batch item by the integers from 0 to 99. Since there is zero + # overlap, the gradient for this batch item will be 0-99 shaped as (10, + # 10). + reconstruction *= array_ops.stack( + [array_ops.zeros((100,)), math_ops.to_float(math_ops.range(100))]) + loss = math_ops.reduce_sum(reconstruction) + + # Verify that only the second batch item receives gradient. + gradient = sess.run(gradients_impl.gradients([loss], [signal])[0]) + expected_gradient = np.stack([ + np.zeros((10, 10)), + np.reshape(np.arange(100).astype(np.float32), (10, 10))]) + self.assertAllEqual(expected_gradient, gradient) + + def test_gradient_numerical(self): + with self.test_session(use_gpu=True): + shape = (2, 10, 10) + framed_signal = array_ops.zeros(shape) + frame_hop = 10 + reconstruction = reconstruction_ops.overlap_and_add( + framed_signal, frame_hop) + error = test.compute_gradient_error( + framed_signal, shape, reconstruction, [2, 100]) + self.assertLess(error, 2e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py new file mode 100644 index 0000000000..f5f443ad09 --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py @@ -0,0 +1,144 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Signal reconstruction via overlapped addition of frames.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.signal.python.ops import shape_ops +from tensorflow.contrib.signal.python.ops import util_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _shuffle_to_front(input_tensor, k): + """Shuffles the last `k` indices of `input_tensor` to the front. + + Transposes `input_tensor` to have the last `k` indices at the front. The input + may have arbitrary rank and unknown shape. + + Args: + input_tensor: A `Tensor` of arbitrary rank and unknown shape. + k: A scalar `Tensor` specifying how many indices to shuffle. + + Returns: + A tranposed version of `input_tensor` with `k` indices shuffled to the + front. + + Raises: + ValueError: If `input_tensor` is not at least rank `k` or `k` is not scalar. + """ + k = ops.convert_to_tensor(k, name="k") + k.shape.with_rank(0) + k_static = tensor_util.constant_value(k) + if k_static is not None: + input_tensor.shape.with_rank_at_least(k_static) + + rank = array_ops.rank(input_tensor) + outer_indices, inner_indices = array_ops.split(math_ops.range(rank), + [rank - k, k]) + permutation = array_ops.concat([inner_indices, outer_indices], 0) + + return array_ops.transpose(input_tensor, perm=permutation) + + +def overlap_and_add(signal, frame_step, name=None): + """Reconstructs a signal from a framed representation. + + Adds potentially overlapping frames of a signal with shape + `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. + The resulting tensor has shape `[..., output_size]` where + + output_size = (frames - 1) * frame_step + frame_length + + Args: + signal: A [..., frames, frame_length] `Tensor`. All dimensions may be + unknown, and rank must be at least 2. + frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be + less than or equal to `frame_length`. + name: An optional name for the operation. + + Returns: + A `Tensor` with shape `[..., output_size]` containing the overlap-added + frames of `signal`'s inner-most two dimensions. + + Raises: + ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar + integer or `frame_step` is greater than `frame_length`. + """ + with ops.name_scope(name, "overlap_and_add", [signal, frame_step]): + signal = ops.convert_to_tensor(signal, name="signal") + signal.shape.with_rank_at_least(2) + frame_step = ops.convert_to_tensor(frame_step, name="frame_step") + frame_step.shape.assert_has_rank(0) + if not frame_step.dtype.is_integer: + raise ValueError("frame_step must be an integer. Got %s" % + frame_step.dtype) + + # If frame_length and frame_step are known at graph construction time, check + # frame_step is less than or equal to frame_length. + frame_step_static = tensor_util.constant_value(frame_step) + if (frame_step_static is not None and signal.shape.ndims is not None and + signal.shape[-1].value is not None and + frame_step_static > signal.shape[-1].value): + raise ValueError( + "frame_step (%d) must be less than or equal to frame_length (%d)" % ( + frame_step_static, signal.shape[-1].value)) + + signal_shape = array_ops.shape(signal) + + # All dimensions that are not part of the overlap-and-add. Can be empty for + # rank 2 inputs. + outer_dimensions = signal_shape[:-2] + + signal_rank = array_ops.rank(signal) + frames = signal_shape[-2] + frame_length = signal_shape[-1] + + subframe_length = util_ops.gcd(frame_length, frame_step) + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + # To avoid overlap-adding sample-by-sample, we overlap-add at the "subframe" + # level, where a subframe is gcd(frame_length, frame_step). Reshape signal + # from [..., frames, frame_length] into [..., subframes, subframe_length]. + subframe_shape = array_ops.concat( + [outer_dimensions, [-1, subframe_length]], 0) + subframe_signal = array_ops.reshape(signal, subframe_shape) + + # Now we shuffle the last [subframes, subframe_length] dimensions to the + # front. + # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can + # avoid this pair of transposes. + subframe_signal = _shuffle_to_front(subframe_signal, 2) + + # Use unsorted_segment_sum to add overlapping subframes together. + segment_ids = array_ops.reshape(shape_ops.frame( + math_ops.range(output_subframes), subframes_per_frame, subframe_step, + pad_end=False), [-1]) + result = math_ops.unsorted_segment_sum(subframe_signal, segment_ids, + num_segments=output_subframes) + + # result is a [subframes, subframe_length, ...outer_dimensions] tensor. We + # return a [...outer_dimensions, output_size] tensor with a transpose and + # reshape. + result_shape = array_ops.concat([outer_dimensions, [output_size]], 0) + return array_ops.reshape(_shuffle_to_front(result, signal_rank - 2), + result_shape) |