aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/signal
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-07-25 10:53:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 11:06:10 -0700
commit879e207e815eacd0612da068c14f9aaf359a87e3 (patch)
treebea2b0b9fbf21ecbba196ed187c42f3fb1036aa3 /tensorflow/contrib/signal
parent4c9e344bf1b6582620b26c0a62a886d3c80e3c19 (diff)
Speed up tf.contrib.signal spectral_ops_test.py by reducing the size of the gradient test.
PiperOrigin-RevId: 163092423
Diffstat (limited to 'tensorflow/contrib/signal')
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
index be28184ae6..61b7107a17 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
@@ -220,15 +220,14 @@ class SpectralOpsTest(test.TestCase):
# stft_bound, inverse_stft_bound).
# TODO(rjryan): Investigate why STFT gradient error is so high.
test_configs = [
- (512, 64, 32, 64, 2e-3, 3e-5),
- (512, 64, 64, 64, 2e-3, 3e-5),
- (512, 64, 25, 64, 2e-3, 3e-5),
- (512, 25, 15, 36, 2e-3, 3e-5),
- (123, 23, 5, 42, 2e-3, 4e-5),
+ (64, 16, 8, 16),
+ (64, 16, 16, 16),
+ (64, 16, 7, 16),
+ (64, 7, 4, 9),
+ (29, 5, 1, 10),
]
- for (signal_length, frame_length, frame_step, fft_length,
- stft_bound, inverse_stft_bound) in test_configs:
+ for (signal_length, frame_length, frame_step, fft_length) in test_configs:
signal_shape = [signal_length]
signal = random_ops.random_uniform(signal_shape)
stft_shape = [max(0, 1 + (signal_length - frame_length) // frame_step),
@@ -242,8 +241,8 @@ class SpectralOpsTest(test.TestCase):
stft, stft_shape)
inverse_stft_error = test.compute_gradient_error(
stft, stft_shape, inverse_stft, inverse_stft_shape)
- self.assertLess(stft_error, stft_bound)
- self.assertLess(inverse_stft_error, inverse_stft_bound)
+ self.assertLess(stft_error, 2e-3)
+ self.assertLess(inverse_stft_error, 4e-5)
if __name__ == "__main__":