diff options
Diffstat (limited to 'tensorflow/python/ops/nn_test.py')
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 87f6f92a8a..cc8c623947 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -830,7 +830,8 @@ class ReluTest(test_lib.TestCase): class MomentsTest(test_lib.TestCase): - def doOutputTest(self, input_shape, moments_axes, tol=1e-4): + def doOutputTest(self, input_shape, moments_axes, tol=1e-4, + check_gradients=False): for mu in [0.0, 1.0, 1e3]: for sigma in [1.0, 0.1]: for keep_dims in [True, False]: @@ -846,6 +847,15 @@ class MomentsTest(test_lib.TestCase): mean, variance = nn_impl.moments( inputs, moments_axes, keep_dims=keep_dims) + if check_gradients: + err = gradient_checker.compute_gradient_error( + inputs, input_shape, mean, mean.shape.as_list()) + self.assertLess(err, 1e-3) + err = gradient_checker.compute_gradient_error( + inputs, input_shape, variance, variance.shape.as_list()) + self.assertLess(err, 1e-3) + + # Evaluate. [mean, variance] = sess.run([mean, variance]) # Make sure that there are no NaNs self.assertFalse(np.isnan(mean).any()) @@ -853,6 +863,12 @@ class MomentsTest(test_lib.TestCase): self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) + def testOutputAndGradient2DInput0(self): + self.doOutputTest((10, 10), (0,), check_gradients=True) + + def testOutputAndGradient2DInput01(self): + self.doOutputTest((10, 10), (0, 1), check_gradients=True) + def testOutput2DInput0(self): self.doOutputTest((10, 300), (0,)) |