aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-10-02 17:01:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 18:30:26 -0700
commit991dea6bedd41e27590c29212855c89a09b2bfb3 (patch)
tree9d42ec69f17488cca5da75d7d91977720a6e0331 /tensorflow/contrib/signal
parent931268a690ab9fd875962945af0c7a66b8b5d9fe (diff)
[tf-signal] Add a test that windowing, framing, and mel ops are constant foldable for constant inputs.
PiperOrigin-RevId: 170777731
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/BUILD14
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py11
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py16
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/test_util.py46
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py13
5 files changed, 100 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 6025ec5b57..80bcb9632e 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -24,11 +24,23 @@ py_library(
],
)
+py_library(
+ name = "test_util",
+ srcs = ["python/kernel_tests/test_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:tf_optimizer",
+ ],
+)
+
cuda_py_tests(
name = "mel_ops_test",
srcs = ["python/kernel_tests/mel_ops_test.py"],
additional_deps = [
":signal_py",
+ ":test_util",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
@@ -70,6 +82,7 @@ cuda_py_tests(
srcs = ["python/kernel_tests/shape_ops_test.py"],
additional_deps = [
":signal_py",
+ ":test_util",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:math_ops",
@@ -107,6 +120,7 @@ cuda_py_tests(
srcs = ["python/kernel_tests/window_ops_test.py"],
additional_deps = [
":signal_py",
+ ":test_util",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
index f107b53f01..b861476b67 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import mel_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.platform import test
# mel spectrum constants and functions.
@@ -159,6 +161,15 @@ class LinearToMelTest(test.TestCase):
with self.assertRaises(ValueError):
mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32)
+ def test_constant_folding(self):
+ """Mel functions should be constant foldable."""
+ for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+ g = ops.Graph()
+ with g.as_default():
+ mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
+ rewritten_graph = test_util.grappler_optimize(g, [mel_matrix])
+ self.assertEqual(1, len(rewritten_graph.node))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index 8633ced599..1c052354b8 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import shape_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -334,5 +336,19 @@ class FrameTest(test.TestCase):
signal, signal_shape, frames, frames.shape.as_list())
self.assertLess(error, 2e-5)
+ def test_constant_folding(self):
+ """frame should be constant foldable for constant inputs."""
+ for pad_end in [False, True]:
+ g = ops.Graph()
+ with g.as_default():
+ frame_length, frame_step = 32, 16
+ signal_shape = (2, 128)
+ signal = array_ops.ones(signal_shape)
+ frames = shape_ops.frame(signal, frame_length, frame_step,
+ pad_end=pad_end)
+ rewritten_graph = test_util.grappler_optimize(g, [frames])
+ self.assertEqual(1, len(rewritten_graph.node))
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py
new file mode 100644
index 0000000000..9a3603b6a9
--- /dev/null
+++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.contrib.signal."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.training import saver
+
+
+def grappler_optimize(graph, fetches=None, rewriter_config=None):
+ """Tries to optimize the provided graph using grappler.
+
+ Args:
+ graph: A @{tf.Graph} instance containing the graph to optimize.
+ fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away).
+ Grappler uses the 'train_op' collection to look for fetches, so if not
+ provided this collection should be non-empty.
+ rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the
+ graph.
+
+ Returns:
+ A @{tf.GraphDef} containing the rewritten graph.
+ """
+ if rewriter_config is None:
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ if fetches is not None:
+ for fetch in fetches:
+ graph.add_to_collection('train_op', fetch)
+ metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
+ return tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
diff --git a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py
index c3e0464596..5a464699da 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py
@@ -22,8 +22,10 @@ import functools
import numpy as np
+from tensorflow.contrib.signal.python.kernel_tests import test_util
from tensorflow.contrib.signal.python.ops import window_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@@ -91,6 +93,17 @@ class WindowOpsTest(test.TestCase):
functools.partial(_scipy_raised_cosine, a=0.54, b=0.46),
window_ops.hamming_window)
+ def test_constant_folding(self):
+ """Window functions should be constant foldable for constant inputs."""
+ for window_fn in (window_ops.hann_window, window_ops.hamming_window):
+ for dtype, _ in self._dtypes:
+ for periodic in [False, True]:
+ g = ops.Graph()
+ with g.as_default():
+ window = window_fn(100, periodic=periodic, dtype=dtype)
+ rewritten_graph = test_util.grappler_optimize(g, [window])
+ self.assertEqual(1, len(rewritten_graph.node))
+
if __name__ == '__main__':
test.main()