aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
author: RJ Ryan <rjryan@google.com>, 2017-07-14 15:32:52 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-07-14 15:37:18 -0700
commit: ff5b2bf093575b852848b2d9a1fc31684a5800da (patch)
tree: 814de599089b07d715424fd31154c9202e8667a4 /tensorflow/contrib/signal
parent: cc741b48e8bba90670fe9db3811d50f283ea8b1b (diff)
Add overlap_and_add to tf.contrib.signal.
This is the inverse operation of tf.contrib.signal.frame. It is implemented using GPU-capable ops and supports gradients.

PiperOrigin-RevId: 162017464
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- tensorflow/contrib/signal/BUILD | 18
-rw-r--r-- tensorflow/contrib/signal/__init__.py | 2
-rw-r--r-- tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py | 192
-rw-r--r-- tensorflow/contrib/signal/python/ops/reconstruction_ops.py | 144
4 files changed, 355 insertions, 1 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 6ed2313dca..706de58a0a 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -22,8 +22,24 @@ py_library(
)
# Test target for python/kernel_tests/reconstruction_ops_test.py, which
# exercises the overlap_and_add op exported from :signal_py.
cuda_py_tests(
    name = "reconstruction_ops_test",
    srcs = ["python/kernel_tests/reconstruction_ops_test.py"],
    additional_deps = [
        ":signal_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
        "//third_party/py/numpy",
    ],
)
+
+cuda_py_tests(
name = "shape_ops_test",
- size = "small",
srcs = ["python/kernel_tests/shape_ops_test.py"],
additional_deps = [
":signal_py",
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 780a090971..d0f6c1d0c6 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -17,12 +17,14 @@
@@frame
@@hamming_window
@@hann_window
+@@overlap_and_add
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add
from tensorflow.contrib.signal.python.ops.shape_ops import frame
# `frame` used to be named `frames`, which is a noun and not a verb.
# Keep an alias to `frames` for backwards compatibility.
diff --git a/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py
new file mode 100644
index 0000000000..5c9b2ac518
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/reconstruction_ops_test.py
@@ -0,0 +1,192 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for reconstruction_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.signal.python.ops import reconstruction_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
class ReconstructionOpsTest(test.TestCase):
  """Tests for `reconstruction_ops.overlap_and_add`."""

  def __init__(self, *args, **kwargs):
    super(ReconstructionOpsTest, self).__init__(*args, **kwargs)
    self.batch_size = 3
    self.frames = 3
    self.samples = 5

    # One base per batch item. Sample `s` of frame `f` in batch item `b` holds
    # bases[b] ** (f * samples + s), so every input sample is a distinct power
    # of its base and contributes a single "1" digit when an output sample is
    # printed in that base.
    self.bases = np.array(range(2, 5))
    exponents = np.array(range(self.frames * self.samples))
    powers = np.power(self.bases[:, np.newaxis], exponents[np.newaxis, :])

    self.powers = np.reshape(powers, [self.batch_size, self.frames,
                                      self.samples])
    self.frame_hop = 2

    # Hand computed example using powers of unique numbers: this is easily
    # verified. Entry i is output sample i written in the batch item's base;
    # the position of each "1" identifies which input samples were summed.
    self.expected_string = ["1", "10", "100100", "1001000", "10010010000",
                            "100100000000", "1001000000000", "10000000000000",
                            "100000000000000"]

  def test_all_ones(self):
    """Overlap-adding frames of ones yields the per-sample overlap count."""
    signal = constant_op.constant(np.ones((3, 5)), dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, 2)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

      expected_output = np.array([1, 1, 2, 2, 3, 2, 2, 1, 1])

      self.assertAllClose(output, expected_output)

  def test_simple(self):
    """Checks small hand-computed cases for every hop <= frame length."""
    def make_input(frame_length, num_frames=3):
      """Generate a tensor of num_frames frames of frame_length."""
      return np.reshape(np.arange(1, num_frames * frame_length + 1),
                        (-1, frame_length))

    # List of (signal, expected_result, frame_hop).
    configurations = [
        # All hop lengths on a frame length of 2.
        (make_input(2), [1, 5, 9, 6], 1),
        (make_input(2), [1, 2, 3, 4, 5, 6], 2),

        # All hop lengths on a frame length of 3.
        (make_input(3), [1, 6, 15, 14, 9], 1),
        (make_input(3), [1, 2, 7, 5, 13, 8, 9], 2),
        (make_input(3), [1, 2, 3, 4, 5, 6, 7, 8, 9], 3),

        # All hop lengths on a frame length of 4.
        (make_input(4), [1, 7, 18, 21, 19, 12], 1),
        (make_input(4), [1, 2, 8, 10, 16, 18, 11, 12], 2),
        (make_input(4), [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3),
        (make_input(4), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4),
    ]

    with self.test_session(use_gpu=True):
      for signal, expected, frame_hop in configurations:
        reconstruction = reconstruction_ops.overlap_and_add(
            np.array(signal), frame_hop).eval()
        expected_output = np.array(expected)
        self.assertAllClose(reconstruction, expected_output)

  def test_powers(self):
    """Checks a single rank-2 input built from powers of the first base."""
    signal = constant_op.constant(np.squeeze(self.powers[0, :, :]),
                                  dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)
      string_output = [np.base_repr(x, self.bases[0]) for x in output]

      self.assertEqual(string_output, self.expected_string)

  def test_batch(self):
    """Checks that every item of a rank-3 batch is reconstructed correctly."""
    signal = constant_op.constant(self.powers, dtype=dtypes.int64)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

      # Assert per batch item (instead of and-ing the comparisons into one
      # boolean) so that a failure reports the offending item and the
      # mismatching digit strings.
      for i in range(self.batch_size):
        string_output = [np.base_repr(x, self.bases[i]) for x in output[i, :]]
        self.assertEqual(string_output, self.expected_string,
                         msg="Mismatch in batch item %d" % i)

  def test_one_element_batch(self):
    """Checks a float32 batch holding a single item keeps its batch axis."""
    input_matrix = np.squeeze(self.powers[0, :, :])
    input_matrix = input_matrix[np.newaxis, :, :].astype(float)
    signal = constant_op.constant(input_matrix, dtype=dtypes.float32)
    reconstruction = reconstruction_ops.overlap_and_add(signal, self.frame_hop)

    with self.test_session(use_gpu=True) as sess:
      output = sess.run(reconstruction)

      string_output = [np.base_repr(int(x), self.bases[0]) for x in
                       np.squeeze(output)]

      self.assertEqual(output.shape, (1, 9))
      self.assertEqual(string_output, self.expected_string)

  def test_gradient(self):
    """Checks the gradient of sum(overlap_and_add(x)) w.r.t. x is all ones."""
    # List of (signal_shape, frame_hop).
    configurations = [
        ((1, 128), 1),
        ((5, 35), 17),
        ((10, 128), 128),
        ((2, 10, 128), 127),
        ((2, 2, 10, 128), 126),
        ((2, 2, 2, 10, 128), 125),
    ]

    for shape, frame_hop in configurations:
      with self.test_session(use_gpu=True) as sess:
        signal = array_ops.zeros(shape)
        reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)
        loss = math_ops.reduce_sum(reconstruction)
        # Increasing any sample in the input frames by one will increase the sum
        # of all the samples in the reconstruction by 1, so the gradient should
        # be all ones, no matter the shape or hop.
        gradient = sess.run(gradients_impl.gradients([loss], [signal])[0])
        # Comparing against an explicit array of ones also verifies the
        # gradient's shape and prints the differing values on failure.
        self.assertAllEqual(np.ones(shape), gradient)

  def test_gradient_batch(self):
    """Checks that gradient does not leak between batch items."""
    with self.test_session(use_gpu=True) as sess:
      signal = array_ops.zeros((2, 10, 10))
      frame_hop = 10
      reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)

      # Multiply the first batch-item's reconstruction by zeros. This will block
      # gradient from flowing into the first batch item from the loss. Multiply
      # the second batch item by the integers from 0 to 99. Since there is zero
      # overlap, the gradient for this batch item will be 0-99 shaped as (10,
      # 10).
      reconstruction *= array_ops.stack(
          [array_ops.zeros((100,)), math_ops.to_float(math_ops.range(100))])
      loss = math_ops.reduce_sum(reconstruction)

      # Verify that only the second batch item receives gradient.
      gradient = sess.run(gradients_impl.gradients([loss], [signal])[0])
      expected_gradient = np.stack([
          np.zeros((10, 10)),
          np.reshape(np.arange(100).astype(np.float32), (10, 10))])
      self.assertAllEqual(expected_gradient, gradient)

  def test_gradient_numerical(self):
    """Compares the registered gradient against a numerical estimate."""
    with self.test_session(use_gpu=True):
      shape = (2, 10, 10)
      framed_signal = array_ops.zeros(shape)
      frame_hop = 10
      reconstruction = reconstruction_ops.overlap_and_add(
          framed_signal, frame_hop)
      # [2, 100] is the reconstruction's shape: 2 batch items, each with
      # (10 - 1) * 10 + 10 = 100 output samples.
      error = test.compute_gradient_error(
          framed_signal, shape, reconstruction, [2, 100])
      self.assertLess(error, 2e-5)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
new file mode 100644
index 0000000000..f5f443ad09
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Signal reconstruction via overlapped addition of frames."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.signal.python.ops import shape_ops
+from tensorflow.contrib.signal.python.ops import util_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
def _shuffle_to_front(input_tensor, k):
  """Moves the trailing `k` axes of `input_tensor` to the leading positions.

  The axis order `[0, ..., rank - k - 1, rank - k, ..., rank - 1]` becomes
  `[rank - k, ..., rank - 1, 0, ..., rank - k - 1]`; the moved axes and the
  remaining axes each keep their relative order. The permutation is built from
  graph ops, so `input_tensor` may have arbitrary rank and unknown shape.

  Args:
    input_tensor: A `Tensor` of arbitrary rank and unknown shape.
    k: A scalar `Tensor` specifying how many trailing axes to move.

  Returns:
    A transposed version of `input_tensor` with its last `k` axes moved to the
    front.

  Raises:
    ValueError: If `input_tensor` is not at least rank `k` or `k` is not scalar.
  """
  k = ops.convert_to_tensor(k, name="k")
  k.shape.with_rank(0)

  # The rank of `input_tensor` can only be validated at graph construction
  # time when the value of `k` is statically known.
  static_k = tensor_util.constant_value(k)
  if static_k is not None:
    input_tensor.shape.with_rank_at_least(static_k)

  rank = array_ops.rank(input_tensor)
  all_axes = math_ops.range(rank)
  leading_axes, trailing_axes = array_ops.split(all_axes, [rank - k, k])
  permutation = array_ops.concat([trailing_axes, leading_axes], 0)
  return array_ops.transpose(input_tensor, perm=permutation)
+
+
def overlap_and_add(signal, frame_step, name=None):
  """Sums overlapping frames back into a signal.

  Takes a framed `signal` of shape `[..., frames, frame_length]` and adds the
  frames together, placing each frame `frame_step` samples after the previous
  one. The output has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      less than or equal to `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar
      integer or `frame_step` is greater than `frame_length`.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # The frame_step <= frame_length requirement can only be enforced here when
    # both values are known at graph construction time.
    static_frame_step = tensor_util.constant_value(frame_step)
    if static_frame_step is not None and signal.shape.ndims is not None:
      static_frame_length = signal.shape[-1].value
      if (static_frame_length is not None and
          static_frame_step > static_frame_length):
        raise ValueError(
            "frame_step (%d) must be less than or equal to frame_length (%d)" % (
                static_frame_step, static_frame_length))

    dynamic_shape = array_ops.shape(signal)

    # Leading dimensions that do not take part in the overlap-and-add. Empty
    # for rank 2 inputs.
    batch_dims = dynamic_shape[:-2]

    def _with_batch_dims(inner_dims):
      """Returns the shape `[...batch_dims, ...inner_dims]`."""
      return array_ops.concat([batch_dims, inner_dims], 0)

    signal_rank = array_ops.rank(signal)
    frames = dynamic_shape[-2]
    frame_length = dynamic_shape[-1]

    # Rather than overlap-adding sample-by-sample, work on "subframes" of
    # gcd(frame_length, frame_step) samples: both the frame length and the
    # frame step are whole multiples of that size.
    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # Reshape signal from [..., frames, frame_length] into
    # [..., subframes, subframe_length].
    subframe_signal = array_ops.reshape(
        signal, _with_batch_dims([-1, subframe_length]))

    # unsorted_segment_sum segments along the leading axis, so move the
    # [subframes, subframe_length] axes to the front.
    # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can
    # avoid this pair of transposes.
    subframe_signal = _shuffle_to_front(subframe_signal, 2)

    # Framing the output subframe indices with the same (length, step) layout
    # as the input gives, for every input subframe, the index of the output
    # subframe it is added into.
    framed_indices = shape_ops.frame(
        math_ops.range(output_subframes), subframes_per_frame, subframe_step,
        pad_end=False)
    segment_ids = array_ops.reshape(framed_indices, [-1])
    summed = math_ops.unsorted_segment_sum(subframe_signal, segment_ids,
                                           num_segments=output_subframes)

    # summed is a [subframes, subframe_length, ...batch_dims] tensor. Move the
    # batch dimensions back to the front and flatten the two inner axes to get
    # the [...batch_dims, output_size] result.
    result_shape = _with_batch_dims([output_size])
    restored = _shuffle_to_front(summed, signal_rank - 2)
    return array_ops.reshape(restored, result_shape)