diff options
Diffstat (limited to 'tensorflow/python/ops/nn_test.py')
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 66bc0803b7..6767564024 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -878,11 +878,13 @@ class LeakyReluTest(test_lib.TestCase): self.assertAllClose(inputs, outputs) def testValues(self): - np_values = np.array([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=np.float32) - outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) - with self.test_session() as sess: - outputs = sess.run(outputs) - self.assertAllClose(outputs, [-0.2, 0.0, 0.5, 1.0, 2.0]) + for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: + np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) + outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + with self.test_session() as sess: + outputs = sess.run(outputs) + tol = 2e-3 if dtype == np.float16 else 1e-6 + self.assertAllClose(outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) class SwishTest(test_lib.TestCase): |