author | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800
---|---|---
committer | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800
commit | 4e72dd865a3fc83baa69f6b7c08720a1b546a464 (patch) |
tree | 9ac73a11393bf248e4f85c64feafda9316081781 /tensorflow/python/kernel_tests/relu_op_test.py |
parent | cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab (diff) |
Refine LeakyRelu code.
1. Add the C++ gradient-of-gradient definition for LeakyRelu and the relevant unit tests.
2. Use the forward compatibility layer for the Python code changes.
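Point 2 refers to TensorFlow's forward compatibility layer (`tensorflow.python.compat.compat`), which gates the use of new graph ops behind a horizon date so that GraphDefs produced by a new Python frontend remain loadable by slightly older binaries. A minimal sketch of the pattern, with a hypothetical dispatch function and an illustrative date (only the test-side horizon of 2018-10-02 appears in this diff; the actual nn_grad.py / nn_ops.py changes are not shown here):

```python
# Sketch only: the dispatch site and horizon date below are illustrative,
# not the exact code changed by this commit.
from tensorflow.python.compat import compat
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops


def _leaky_relu_grad(grads, features, alpha=0.2):
  if compat.forward_compatible(2018, 10, 1):
    # New path: the dedicated C++ LeakyReluGrad kernel, whose own gradient
    # is now registered, so gradient-of-gradient works end to end.
    return gen_nn_ops.leaky_relu_grad(grads, features, alpha=alpha)
  # Old path: compose the same gradient from pre-existing ops so that
  # GraphDefs built before the horizon stay loadable by older binaries.
  return array_ops.where(features > 0, grads, alpha * grads)
```

The tests in the diff below correspondingly wrap their bodies in `compat.forward_compatibility_horizon(2018, 10, 2)`, a context manager that temporarily pretends the horizon has passed so the new gradient path is exercised.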
Diffstat (limited to 'tensorflow/python/kernel_tests/relu_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/relu_op_test.py | 70 |
1 file changed, 37 insertions, 33 deletions
```diff
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index ccb3a231bb..7066f28883 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
@@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase):
           np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
           alpha=0.1,
           use_gpu=True)
-  # The gradient test for ReLU is a bit tricky as the derivative is not well
-  # defined at around zero and we want to avoid that in terms of input values.
+  # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+  # well defined at around zero and we want to avoid that in terms of input
+  # values.
   def testGradientFloat32(self):
     with self.test_session():
       x = constant_op.constant(
@@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase):
       self.assertLess(err, 1e-10)
 
   def testGradGradFloat32(self):
-    with self.test_session():
-      x = constant_op.constant(
-          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
-          shape=[2, 5],
-          name="x")
-      y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
-      z = gradients_impl.gradients(y, x)
-      x_init = np.asarray(
-          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-          dtype=np.float32,
-          order="F")
-      err = gradient_checker.compute_gradient_error(
-          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
-      print("leaky_relu (float32) gradient of gradient err = ", err)
-      self.assertLess(err, 1e-4)
+    with compat.forward_compatibility_horizon(2018, 10, 2):
+      with self.test_session():
+        x = constant_op.constant(
+            [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+            shape=[2, 5],
+            name="x")
+        y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+        z = gradients_impl.gradients(y, x)
+        x_init = np.asarray(
+            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+            dtype=np.float32,
+            order="F")
+        err = gradient_checker.compute_gradient_error(
+            x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+        print("leaky_relu (float32) gradient of gradient err = ", err)
+        self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with self.test_session():
-      x = constant_op.constant(
-          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
-          shape=[2, 5],
-          dtype=dtypes.float64,
-          name="x")
-      y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
-      z = gradients_impl.gradients(y, x)
-      x_init = np.asarray(
-          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-          dtype=np.float64,
-          order="F")
-      err = gradient_checker.compute_gradient_error(
-          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
-      print("leaky_relu (float64) gradient of gradient err = ", err)
-      self.assertLess(err, 1e-10)
+    with compat.forward_compatibility_horizon(2018, 10, 2):
+      with self.test_session():
+        x = constant_op.constant(
+            [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+            shape=[2, 5],
+            dtype=dtypes.float64,
+            name="x")
+        y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+        z = gradients_impl.gradients(y, x)
+        x_init = np.asarray(
+            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+            dtype=np.float64,
+            order="F")
+        err = gradient_checker.compute_gradient_error(
+            x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+        print("leaky_relu (float64) gradient of gradient err = ", err)
+        self.assertLess(err, 1e-10)
 
   def testGradientScalar(self):
     with self.test_session() as sess:
```
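For intuition, here is what the new `testGradGrad*` cases verify, restated in plain numpy (a sketch, not part of the commit): the first derivative of leaky_relu is piecewise constant (1 for x > 0, alpha otherwise), so its second derivative is zero everywhere except at the origin, which is exactly why the test inputs avoid zero.

```python
import numpy as np

alpha = 0.1

def leaky_relu_grad(x):
  # First derivative of leaky_relu: piecewise constant in x.
  return np.where(x > 0, 1.0, alpha)

# Same evaluation points as the test, deliberately excluding zero.
x = np.array([-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9])
eps = 1e-3
# Central finite difference of the first derivative: a numerical
# gradient-of-gradient, which should match the analytic value of zero.
numeric = (leaky_relu_grad(x + eps) - leaky_relu_grad(x - eps)) / (2 * eps)
print(numeric.max())  # ~0.0, consistent with err < 1e-4 in the test
```

The test itself checks the same thing symbolically: `gradient_checker.compute_gradient_error` compares the registered second-order gradient of the `LeakyReluGrad` op against finite differences and asserts the discrepancy is tiny.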