author | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800
---|---|---
committer | Cao Zongyan <zongyan.cao@alibaba-inc.com> | 2018-08-29 17:05:43 +0800
commit | 4e72dd865a3fc83baa69f6b7c08720a1b546a464 (patch) |
tree | 9ac73a11393bf248e4f85c64feafda9316081781 |
parent | cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab (diff) |
Refine LeakyRelu code.
1. Add the C++ gradient-of-gradient definition for LeakyRelu and the relevant unit tests.
2. Use the forward compatibility layer for the Python code changes.
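Background on the new C++ gradient-of-gradient: `LeakyReluGrad(gradients, features)` multiplies the incoming gradients by LeakyReLU's piecewise-constant derivative, and since that product is linear in `gradients`, its own gradient is the very same masking op applied to the next incoming gradient. A minimal NumPy sketch of the identity being exploited (illustrative, not TensorFlow code):

```python
import numpy as np

def leaky_relu_grad(gradients, features, alpha=0.2):
    # Backward op: scale incoming gradients by the piecewise-constant
    # derivative of LeakyReLU (1 where features > 0, alpha elsewhere).
    return np.where(features > 0, gradients, alpha * gradients)

# The op is linear in `gradients`, so differentiating it again w.r.t.
# `gradients` applies the same mask to the new incoming gradient, while
# the derivative w.r.t. `features` is zero almost everywhere. This is
# why the LeakyReluGradGradHelper below reuses LeakyReluGrad on
# op.input(1) (the features) and pushes NoGradient() for that argument.
g = np.ones(3)
f = np.array([-2.0, 0.5, 3.0])
print(leaky_relu_grad(g, f, alpha=0.1))  # [0.1 1.  1. ]
```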
-rw-r--r-- | tensorflow/cc/gradients/nn_grad.cc | 18
-rw-r--r-- | tensorflow/cc/gradients/nn_grad_test.cc | 16
-rw-r--r-- | tensorflow/python/kernel_tests/relu_op_test.py | 70
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 5
4 files changed, 73 insertions(+), 36 deletions(-)
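An aside on why both the C++ and Python tests below keep their input values away from zero: a central-difference gradient check that straddles LeakyReLU's kink at zero measures a blend of the two slopes rather than either one-sided derivative, so it would disagree with the analytic gradient. A tiny illustrative check (values hypothetical):

```python
# Central differences near LeakyReLU's kink blend the two slopes.
alpha, eps = 0.1, 0.05
leaky = lambda x: x if x > 0 else alpha * x

x = 0.01  # closer to zero than the step size eps
numeric = (leaky(x + eps) - leaky(x - eps)) / (2 * eps)
print(numeric)  # ~0.64 -- matches neither slope (1.0 nor alpha = 0.1)
```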
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 0fc23d0bf7..2a32a2ed6f 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -149,13 +149,27 @@ Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
   float alpha;
   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
   internal::LeakyReluGrad::Attrs attrs;
-  attrs.Alpha(alpha);
-  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs);
+  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
+                                    attrs.Alpha(alpha));
   grad_outputs->push_back(dx);
   return scope.status();
 }
 REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);
 
+Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
+                               const std::vector<Output>& grad_inputs,
+                               std::vector<Output>* grad_outputs) {
+  float alpha;
+  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+  internal::LeakyReluGrad::Attrs attrs;
+  auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
+                                    attrs.Alpha(alpha));
+  grad_outputs->push_back(dx);
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);
+
 Status EluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index 5ebece7b6e..bf0db1f59d 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/cc/framework/gradient_checker.h"
 #include "tensorflow/cc/framework/testutil.h"
 #include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -173,6 +174,21 @@ TEST_F(NNGradTest, LeakyReluGrad) {
   RunTest(x, x_init_value, y, shape);
 }
 
+TEST_F(NNGradTest, LeakyReluGradGrad) {
+  TensorShape shape({5, 2});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  // Avoid input values where Leaky ReLU gradient is not well defined (around
+  // zero).
+  Tensor x_init_value = test::AsTensor<float>(
+      {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f},
+      {5, 2});
+  Tensor features = test::AsTensor<float>(
+      {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+      {5, 2});
+  auto y = ops::internal::LeakyReluGrad(scope_, x, features);
+  RunTest(x, x_init_value, y, shape);
+}
+
 TEST_F(NNGradTest, EluGrad) {
   TensorShape shape({5, 2});
   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index ccb3a231bb..7066f28883 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
@@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase):
             np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
             alpha=0.1,
             use_gpu=True)
-  # The gradient test for ReLU is a bit tricky as the derivative is not well
-  # defined at around zero and we want to avoid that in terms of input values.
+  # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+  # well defined at around zero and we want to avoid that in terms of input
+  # values.
   def testGradientFloat32(self):
     with self.test_session():
       x = constant_op.constant(
@@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase):
     self.assertLess(err, 1e-10)
 
   def testGradGradFloat32(self):
-    with self.test_session():
-      x = constant_op.constant(
-          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
-          shape=[2, 5],
-          name="x")
-      y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
-      z = gradients_impl.gradients(y, x)
-      x_init = np.asarray(
-          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-          dtype=np.float32,
-          order="F")
-      err = gradient_checker.compute_gradient_error(
-          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
-      print("leaky_relu (float32) gradient of gradient err = ", err)
-      self.assertLess(err, 1e-4)
+    with compat.forward_compatibility_horizon(2018, 10, 2):
+      with self.test_session():
+        x = constant_op.constant(
+            [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+            shape=[2, 5],
+            name="x")
+        y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+        z = gradients_impl.gradients(y, x)
+        x_init = np.asarray(
+            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+            dtype=np.float32,
+            order="F")
+        err = gradient_checker.compute_gradient_error(
+            x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+        print("leaky_relu (float32) gradient of gradient err = ", err)
+        self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with self.test_session():
-      x = constant_op.constant(
-          [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
-          shape=[2, 5],
-          dtype=dtypes.float64,
-          name="x")
-      y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
-      z = gradients_impl.gradients(y, x)
-      x_init = np.asarray(
-          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-          dtype=np.float64,
-          order="F")
-      err = gradient_checker.compute_gradient_error(
-          x, [2, 5], z[0], [2, 5], x_init_value=x_init)
-      print("leaky_relu (float64) gradient of gradient err = ", err)
-      self.assertLess(err, 1e-10)
+    with compat.forward_compatibility_horizon(2018, 10, 2):
+      with self.test_session():
+        x = constant_op.constant(
+            [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+            shape=[2, 5],
+            dtype=dtypes.float64,
+            name="x")
+        y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+        z = gradients_impl.gradients(y, x)
+        x_init = np.asarray(
+            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+            dtype=np.float64,
+            order="F")
+        err = gradient_checker.compute_gradient_error(
+            x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+        print("leaky_relu (float64) gradient of gradient err = ", err)
+        self.assertLess(err, 1e-10)
 
   def testGradientScalar(self):
     with self.test_session() as sess:
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 31b8f3945d..52ea202636 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1601,7 +1601,10 @@ def leaky_relu(features, alpha=0.2, name=None):
   features = ops.convert_to_tensor(features, name="features")
   if features.dtype.is_integer:
     features = math_ops.to_float(features)
-  return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
+  if compat.forward_compatible(2018, 10, 1):
+    return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
+  alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
+  return math_ops.maximum(alpha * features, features, name=name)
 
 
 def _flatten_outer_dims(logits):
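A note on the `compat` gating visible in `nn_ops.py` and the tests above: `compat.forward_compatible(year, month, day)` returns True only once TensorFlow's forward-compatibility horizon has moved past that date, so freshly generated graphs avoid constructs that binaries still inside the compatibility window could not consume, and `compat.forward_compatibility_horizon(...)` temporarily advances the horizon so tests can exercise the new path early. A simplified model of the mechanism using only the standard library (the horizon value here is illustrative; the real one lives in tensorflow/python/compat/compat.py):

```python
import contextlib
import datetime

# Illustrative stand-in for TensorFlow's horizon date, which each
# release advances.
_HORIZON = datetime.date(2018, 9, 10)

def forward_compatible(year, month, day):
    # Safe to emit a feature added on (year, month, day) only after the
    # horizon has moved past that date.
    return _HORIZON > datetime.date(year, month, day)

@contextlib.contextmanager
def forward_compatibility_horizon(year, month, day):
    # Tests advance the horizon temporarily to exercise gated code paths.
    global _HORIZON
    old, _HORIZON = _HORIZON, datetime.date(year, month, day)
    try:
        yield
    finally:
        _HORIZON = old

print(forward_compatible(2018, 10, 1))      # False: horizon not yet past
with forward_compatibility_horizon(2018, 10, 2):
    print(forward_compatible(2018, 10, 1))  # True inside the test context
```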