aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-14 15:15:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 15:26:16 -0800
commitab98b2e4c378e62e9d7a4fbc1fda090083db0bcc (patch)
tree5bf7dcd2fb8b29e4fc9d39f0bba7505fcf6821a8
parent405d37a0de7a904a26da3a4810451fc70e93fb67 (diff)
Create test that validates relu(nan) = nan
Change: 147527599
-rw-r--r--tensorflow/python/ops/nn_test.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index f48e420e87..fdb036ebfd 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -788,13 +788,33 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
class CReluTest(test_lib.TestCase):
def test(self):
- x = np.random.rand(3, 4).astype(np.float32)
+ np.random.seed(1) # Make it reproducible.
+ x = np.random.randn(3, 4).astype(np.float32)
y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1)
with self.test_session():
z = nn_ops.crelu(constant_op.constant(x)).eval()
self.assertAllClose(y, z, 1e-4)
+class ReluTest(test_lib.TestCase):
+
+ def test(self):
+ np.random.seed(1) # Make it reproducible.
+ x = np.random.randn(3, 4).astype(np.float32)
+ y = np.maximum(x, 0.0)
+ with self.test_session():
+ z = nn_ops.relu(constant_op.constant(x)).eval()
+ self.assertAllEqual(y, z)
+
+ def testNaNs(self):
+ # Test that relu(nan) = nan for various sizes.
+ for i in range(18):
+ x = np.zeros(i) + np.nan
+ with self.test_session():
+ z = nn_ops.relu(constant_op.constant(x)).eval()
+ self.assertTrue(np.isnan(z).all())
+
+
class MomentsTest(test_lib.TestCase):
def doOutputTest(self, input_shape, moments_axes, tol=1e-4):