aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2017-05-05 09:09:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 10:26:00 -0700
commit692fad20f913ffa2cb874a87578ecabb03cc4557 (patch)
tree172717f537c91b0d1ac0366731b4eb2093fb743b /tensorflow/contrib/signal
parentb329dd821e29e64c93b1b9bf38e61871c6cb53da (diff)
Merge changes from github.
Change: 155209832
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/BUILD46
-rw-r--r--tensorflow/contrib/signal/__init__.py27
-rw-r--r--tensorflow/contrib/signal/python/__init__.py19
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py68
-rw-r--r--tensorflow/contrib/signal/python/ops/__init__.py19
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py87
6 files changed, 266 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
new file mode 100644
index 0000000000..5b65a6ae05
--- /dev/null
+++ b/tensorflow/contrib/signal/BUILD
@@ -0,0 +1,46 @@
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+py_library(
+ name = "signal_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+cuda_py_tests(
+ name = "shape_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/shape_ops_test.py"],
+ additional_deps = [
+ ":signal_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
new file mode 100644
index 0000000000..9f906dd28e
--- /dev/null
+++ b/tensorflow/contrib/signal/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""##Signal ops.
+
+@@frames
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.signal.python.ops.shape_ops import frames
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/signal/python/__init__.py b/tensorflow/contrib/signal/python/__init__.py
new file mode 100644
index 0000000000..e672d1146c
--- /dev/null
+++ b/tensorflow/contrib/signal/python/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Signal ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
new file mode 100644
index 0000000000..e07942875f
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -0,0 +1,68 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for shape_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.signal.python.ops import shape_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class FramesTest(test.TestCase):
+
+ def test_mapping_of_indices_without_padding(self):
+ with self.test_session():
+ tensor = constant_op.constant(np.arange(9152), dtypes.int32)
+ tensor = array_ops.expand_dims(tensor, 0)
+
+ result = shape_ops.frames(tensor, 512, 180)
+ result = result.eval()
+
+ expected = np.tile(np.arange(512), (49, 1))
+ expected += np.tile(np.arange(49) * 180, (512, 1)).T
+
+ expected = np.expand_dims(expected, axis=0)
+ expected = np.array(expected, dtype=np.int32)
+
+ self.assertAllEqual(expected, result)
+
+ def test_mapping_of_indices_with_padding(self):
+ with self.test_session():
+ tensor = constant_op.constant(np.arange(10000), dtypes.int32)
+ tensor = array_ops.expand_dims(tensor, 0)
+
+ result = shape_ops.frames(tensor, 512, 192)
+ result = result.eval()
+
+ expected = np.tile(np.arange(512), (51, 1))
+ expected += np.tile(np.arange(51) * 192, (512, 1)).T
+
+ expected[expected >= 10000] = 0
+
+ expected = np.expand_dims(expected, axis=0)
+ expected = np.array(expected, dtype=np.int32)
+
+ self.assertAllEqual(expected, result)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/signal/python/ops/__init__.py b/tensorflow/contrib/signal/python/ops/__init__.py
new file mode 100644
index 0000000000..e672d1146c
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Signal ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
new file mode 100644
index 0000000000..4914f19be7
--- /dev/null
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""General shape ops for frames."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def frames(signal, frame_length, frame_step, name=None):
+ """Frame a signal into overlapping frames.
+
+ May be used in front of spectral functions.
+
+ For example:
+
+ ```python
+ pcm = tf.placeholder(tf.float32, [None, 9152])
+ frames = tf.contrib.signal.frames(pcm, 512, 180)
+ magspec = tf.abs(tf.spectral.rfft(frames, [512]))
+ image = tf.expand_dims(magspec, 3)
+ ```
+
+ Args:
+ signal: A `Tensor` of shape `[batch_size, signal_length]`.
+ frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
+ frame_step: An `int32` or `int64` `Tensor`. The step between frames.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.
+
+ Raises:
+ ValueError: if signal does not have rank 2.
+ """
+ with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
+ signal = ops.convert_to_tensor(signal, name="signal")
+ frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
+ frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
+
+ signal_rank = signal.shape.ndims
+
+ if signal_rank != 2:
+ raise ValueError("expected signal to have rank 2 but was " + signal_rank)
+
+ signal_length = array_ops.shape(signal)[1]
+
+ num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
+ num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)
+
+ pad_length = (num_frames - 1) * frame_step + frame_length
+ pad_signal = array_ops.pad(signal, [[0, 0], [0,
+ pad_length - signal_length]])
+
+ indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
+ indices_frames = array_ops.tile(indices_frame, [num_frames, 1])
+
+ indices_step = array_ops.expand_dims(
+ math_ops.range(num_frames) * frame_step, 1)
+ indices_steps = array_ops.tile(indices_step, [1, frame_length])
+
+ indices = indices_frames + indices_steps
+
+ # TODO(androbin): remove `transpose` when `gather` gets `axis` support
+ pad_signal = array_ops.transpose(pad_signal)
+ signal_frames = array_ops.gather(pad_signal, indices)
+ signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])
+
+ return signal_frames