diff options
author | 2017-10-02 17:01:17 -0700 | |
---|---|---|
committer | 2017-10-02 18:30:26 -0700 | |
commit | 991dea6bedd41e27590c29212855c89a09b2bfb3 (patch) | |
tree | 9d42ec69f17488cca5da75d7d91977720a6e0331 /tensorflow/contrib/signal | |
parent | 931268a690ab9fd875962945af0c7a66b8b5d9fe (diff) |
[tf-signal] Add a test that windowing, framing, and mel ops are constant foldable for constant inputs.
PiperOrigin-RevId: 170777731
Diffstat (limited to 'tensorflow/contrib/signal')
5 files changed, 100 insertions, 0 deletions
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 6025ec5b57..80bcb9632e 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -24,11 +24,23 @@ py_library( ], ) +py_library( + name = "test_util", + srcs = ["python/kernel_tests/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python:tf_optimizer", + ], +) + cuda_py_tests( name = "mel_ops_test", srcs = ["python/kernel_tests/mel_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:client_testlib", ], @@ -70,6 +82,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/shape_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", @@ -107,6 +120,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/window_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index f107b53f01..b861476b67 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import mel_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test # mel spectrum constants and functions. @@ -159,6 +161,15 @@ class LinearToMelTest(test.TestCase): with self.assertRaises(ValueError): mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32) + def test_constant_folding(self): + """Mel functions should be constant foldable.""" + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + g = ops.Graph() + with g.as_default(): + mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [mel_matrix]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index 8633ced599..1c052354b8 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -20,9 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import shape_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -334,5 +336,19 @@ class FrameTest(test.TestCase): signal, signal_shape, frames, frames.shape.as_list()) self.assertLess(error, 2e-5) + def test_constant_folding(self): + """frame should be constant foldable for constant inputs.""" + for pad_end in [False, True]: + g = ops.Graph() + with g.as_default(): + frame_length, frame_step = 32, 16 + signal_shape = (2, 128) + signal = array_ops.ones(signal_shape) + frames = shape_ops.frame(signal, frame_length, frame_step, + pad_end=pad_end) + rewritten_graph = test_util.grappler_optimize(g, [frames]) + self.assertEqual(1, len(rewritten_graph.node)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py new file mode 100644 index 0000000000..9a3603b6a9 --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.contrib.signal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import saver + + +def grappler_optimize(graph, fetches=None, rewriter_config=None): + """Tries to optimize the provided graph using grappler. + + Args: + graph: A @{tf.Graph} instance containing the graph to optimize. + fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away). + Grappler uses the 'train_op' collection to look for fetches, so if not + provided this collection should be non-empty. + rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the + graph. + + Returns: + A @{tf.GraphDef} containing the rewritten graph. + """ + if rewriter_config is None: + rewriter_config = rewriter_config_pb2.RewriterConfig() + if fetches is not None: + for fetch in fetches: + graph.add_to_collection('train_op', fetch) + metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def()) + return tf_optimizer.OptimizeGraph(rewriter_config, metagraph) diff --git a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py index c3e0464596..5a464699da 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py @@ -22,8 +22,10 @@ import functools import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -91,6 +93,17 @@ class WindowOpsTest(test.TestCase): functools.partial(_scipy_raised_cosine, a=0.54, b=0.46), window_ops.hamming_window) + def test_constant_folding(self): + """Window functions should be constant foldable for constant inputs.""" + for window_fn in (window_ops.hann_window, window_ops.hamming_window): + for dtype, _ in self._dtypes: + for periodic in [False, True]: + g = ops.Graph() + with g.as_default(): + window = window_fn(100, periodic=periodic, dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [window]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == '__main__': test.main() |