aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_test.py')
-rw-r--r--tensorflow/python/ops/nn_test.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 66bc0803b7..6767564024 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -878,11 +878,13 @@ class LeakyReluTest(test_lib.TestCase):
self.assertAllClose(inputs, outputs)
def testValues(self):
- np_values = np.array([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
- outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
- with self.test_session() as sess:
- outputs = sess.run(outputs)
- self.assertAllClose(outputs, [-0.2, 0.0, 0.5, 1.0, 2.0])
+ for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype)
+ outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
+ with self.test_session() as sess:
+ outputs = sess.run(outputs)
+ tol = 2e-3 if dtype == np.float16 else 1e-6
+ self.assertAllClose(outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
class SwishTest(test_lib.TestCase):