aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_test.py')
-rw-r--r--tensorflow/python/ops/nn_test.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 87f6f92a8a..cc8c623947 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -830,7 +830,8 @@ class ReluTest(test_lib.TestCase):
class MomentsTest(test_lib.TestCase):
- def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
+ def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+ check_gradients=False):
for mu in [0.0, 1.0, 1e3]:
for sigma in [1.0, 0.1]:
for keep_dims in [True, False]:
@@ -846,6 +847,15 @@ class MomentsTest(test_lib.TestCase):
mean, variance = nn_impl.moments(
inputs, moments_axes, keep_dims=keep_dims)
+ if check_gradients:
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, mean, mean.shape.as_list())
+ self.assertLess(err, 1e-3)
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, variance, variance.shape.as_list())
+ self.assertLess(err, 1e-3)
+
+ # Evaluate.
[mean, variance] = sess.run([mean, variance])
# Make sure that there are no NaNs
self.assertFalse(np.isnan(mean).any())
@@ -853,6 +863,12 @@ class MomentsTest(test_lib.TestCase):
self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
+ def testOutputAndGradient2DInput0(self):
+ self.doOutputTest((10, 10), (0,), check_gradients=True)
+
+ def testOutputAndGradient2DInput01(self):
+ self.doOutputTest((10, 10), (0, 1), check_gradients=True)
+
def testOutput2DInput0(self):
self.doOutputTest((10, 300), (0,))