aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/relu_op_test.py
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-08-29 17:05:43 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-08-29 17:05:43 +0800
commit4e72dd865a3fc83baa69f6b7c08720a1b546a464 (patch)
tree9ac73a11393bf248e4f85c64feafda9316081781 /tensorflow/python/kernel_tests/relu_op_test.py
parentcb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab (diff)
Refine LeakyRelu codes.
1. Add C++ gradient of gradient definition of LeakyReLu and revalant UT. 2. Using forward compatibility layer for python code changes.
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py70
1 files changed, 37 insertions, 33 deletions
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index ccb3a231bb..7066f28883 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
alpha=0.1, use_gpu=True)
- # The gradient test for ReLU is a bit tricky as the derivative is not well
- # defined at around zero and we want to avoid that in terms of input values.
+ # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+ # well defined at around zero and we want to avoid that in terms of input
+ # values.
def testGradientFloat32(self):
with self.test_session():
x = constant_op.constant(
@@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradGradFloat32(self):
- with self.test_session():
- x = constant_op.constant(
- [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
- shape=[2, 5],
- name="x")
- y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
- z = gradients_impl.gradients(y, x)
- x_init = np.asarray(
- [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
- dtype=np.float32,
- order="F")
- err = gradient_checker.compute_gradient_error(
- x, [2, 5], z[0], [2, 5], x_init_value=x_init)
- print("leaky_relu (float32) gradient of gradient err = ", err)
- self.assertLess(err, 1e-4)
+ with compat.forward_compatibility_horizon(2018, 10, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
- x = constant_op.constant(
- [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
- shape=[2, 5],
- dtype=dtypes.float64,
- name="x")
- y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
- z = gradients_impl.gradients(y, x)
- x_init = np.asarray(
- [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
- dtype=np.float64,
- order="F")
- err = gradient_checker.compute_gradient_error(
- x, [2, 5], z[0], [2, 5], x_init_value=x_init)
- print("leaky_relu (float64) gradient of gradient err = ", err)
- self.assertLess(err, 1e-10)
+ with compat.forward_compatibility_horizon(2018, 10, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
def testGradientScalar(self):
with self.test_session() as sess: