diff options
author | RJ Ryan <rjryan@google.com> | 2017-07-25 10:53:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-25 11:06:10 -0700 |
commit | 879e207e815eacd0612da068c14f9aaf359a87e3 (patch) | |
tree | bea2b0b9fbf21ecbba196ed187c42f3fb1036aa3 /tensorflow/contrib/signal | |
parent | 4c9e344bf1b6582620b26c0a62a886d3c80e3c19 (diff) |
Speed up tf.contrib.signal spectral_ops_test.py by reducing the size of the gradient test.
PiperOrigin-RevId: 163092423
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r-- | tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index be28184ae6..61b7107a17 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -220,15 +220,14 @@ class SpectralOpsTest(test.TestCase): # stft_bound, inverse_stft_bound). # TODO(rjryan): Investigate why STFT gradient error is so high. test_configs = [ - (512, 64, 32, 64, 2e-3, 3e-5), - (512, 64, 64, 64, 2e-3, 3e-5), - (512, 64, 25, 64, 2e-3, 3e-5), - (512, 25, 15, 36, 2e-3, 3e-5), - (123, 23, 5, 42, 2e-3, 4e-5), + (64, 16, 8, 16), + (64, 16, 16, 16), + (64, 16, 7, 16), + (64, 7, 4, 9), + (29, 5, 1, 10), ] - for (signal_length, frame_length, frame_step, fft_length, - stft_bound, inverse_stft_bound) in test_configs: + for (signal_length, frame_length, frame_step, fft_length) in test_configs: signal_shape = [signal_length] signal = random_ops.random_uniform(signal_shape) stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step), @@ -242,8 +241,8 @@ class SpectralOpsTest(test.TestCase): stft, stft_shape) inverse_stft_error = test.compute_gradient_error( stft, stft_shape, inverse_stft, inverse_stft_shape) - self.assertLess(stft_error, stft_bound) - self.assertLess(inverse_stft_error, inverse_stft_bound) + self.assertLess(stft_error, 2e-3) + self.assertLess(inverse_stft_error, 4e-5) if __name__ == "__main__": |