diff options
author | 2017-02-14 15:15:13 -0800 | |
---|---|---|
committer | 2017-02-14 15:26:16 -0800 | |
commit | ab98b2e4c378e62e9d7a4fbc1fda090083db0bcc (patch) | |
tree | 5bf7dcd2fb8b29e4fc9d39f0bba7505fcf6821a8 | |
parent | 405d37a0de7a904a26da3a4810451fc70e93fb67 (diff) |
Create test that validates relu(nan) = nan
Change: 147527599
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index f48e420e87..fdb036ebfd 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -788,13 +788,33 @@ class ComputeSampledLogitsTest(test_lib.TestCase): class CReluTest(test_lib.TestCase): def test(self): - x = np.random.rand(3, 4).astype(np.float32) + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) with self.test_session(): z = nn_ops.crelu(constant_op.constant(x)).eval() self.assertAllClose(y, z, 1e-4) +class ReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.maximum(x, 0.0) + with self.test_session(): + z = nn_ops.relu(constant_op.constant(x)).eval() + self.assertAllEqual(y, z) + + def testNaNs(self): + # Test that relu(nan) = nan for various sizes. + for i in range(18): + x = np.zeros(i) + np.nan + with self.test_session(): + z = nn_ops.relu(constant_op.constant(x)).eval() + self.assertTrue(np.isnan(z).all()) + + class MomentsTest(test_lib.TestCase): def doOutputTest(self, input_shape, moments_axes, tol=1e-4): |