author    Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-08-29 17:05:43 +0800
committer Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-08-29 17:05:43 +0800
commit    4e72dd865a3fc83baa69f6b7c08720a1b546a464 (patch)
tree      9ac73a11393bf248e4f85c64feafda9316081781 /tensorflow
parent    cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab (diff)
Refine LeakyRelu code.
1. Add the C++ gradient-of-gradient definition for LeakyRelu and the relevant unit tests. 2. Use the forward compatibility layer for the Python code changes.
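Note: the gradient-of-gradient registration mirrors the following math. For 0 <= alpha <= 1, LeakyRelu(x) = max(alpha * x, x), and LeakyReluGrad(dy, x) = dy * (1 if x > 0 else alpha). Since LeakyReluGrad is linear in dy, its gradient with respect to dy is the same mask applied to the backpropagated values, and no gradient flows to the features input (hence NoGradient() in the C++ helper below). A minimal NumPy sketch, with illustrative names that are not part of this patch:

import numpy as np

def leaky_relu(x, alpha=0.2):
    # Forward op: max(alpha * x, x), assuming 0 <= alpha <= 1.
    return np.maximum(alpha * x, x)

def leaky_relu_grad(dy, features, alpha=0.2):
    # LeakyReluGrad: pass the incoming gradient through where the
    # original input was positive, scale it by alpha elsewhere.
    return dy * np.where(features > 0, 1.0, alpha)

def leaky_relu_grad_grad(ddx, features, alpha=0.2):
    # Gradient of LeakyReluGrad w.r.t. its gradient input dy: the op is
    # linear in dy, so the same mask is applied again. The gradient
    # w.r.t. features is zero almost everywhere, which is why the C++
    # helper pushes NoGradient() for that input.
    return ddx * np.where(features > 0, 1.0, alpha)

x = np.array([-0.9, -0.1, 0.1, 0.9])
dy = np.ones_like(x)
print(leaky_relu(x, alpha=0.1))              # [-0.09 -0.01  0.1   0.9 ]
print(leaky_relu_grad(dy, x, alpha=0.1))     # [0.1 0.1 1.  1. ]
print(leaky_relu_grad_grad(dy, x, alpha=0.1))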
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/cc/gradients/nn_grad.cc              18
-rw-r--r--  tensorflow/cc/gradients/nn_grad_test.cc         16
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py  70
-rw-r--r--  tensorflow/python/ops/nn_ops.py                  5
4 files changed, 73 insertions, 36 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 0fc23d0bf7..2a32a2ed6f 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -149,13 +149,27 @@ Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
float alpha;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
internal::LeakyReluGrad::Attrs attrs;
- attrs.Alpha(alpha);
- auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs);
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
+ attrs.Alpha(alpha));
grad_outputs->push_back(dx);
return scope.status();
}
REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);
+Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);
+
Status EluGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index 5ebece7b6e..bf0db1f59d 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -173,6 +174,21 @@ TEST_F(NNGradTest, LeakyReluGrad) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NNGradTest, LeakyReluGradGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f},
+ {5, 2});
+ Tensor features = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ auto y = ops::internal::LeakyReluGrad(scope_, x, features);
+ RunTest(x, x_init_value, y, shape);
+}
+
TEST_F(NNGradTest, EluGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index ccb3a231bb..7066f28883 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
alpha=0.1, use_gpu=True)
- # The gradient test for ReLU is a bit tricky as the derivative is not well
- # defined at around zero and we want to avoid that in terms of input values.
+ # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+ # well defined at around zero and we want to avoid that in terms of input
+ # values.
def testGradientFloat32(self):
with self.test_session():
x = constant_op.constant(
@@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradGradFloat32(self):
- with self.test_session():
- x = constant_op.constant(
- [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
- shape=[2, 5],
- name="x")
- y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
- z = gradients_impl.gradients(y, x)
- x_init = np.asarray(
- [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
- dtype=np.float32,
- order="F")
- err = gradient_checker.compute_gradient_error(
- x, [2, 5], z[0], [2, 5], x_init_value=x_init)
- print("leaky_relu (float32) gradient of gradient err = ", err)
- self.assertLess(err, 1e-4)
+ with compat.forward_compatibility_horizon(2018, 10, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
- x = constant_op.constant(
- [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
- shape=[2, 5],
- dtype=dtypes.float64,
- name="x")
- y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
- z = gradients_impl.gradients(y, x)
- x_init = np.asarray(
- [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
- dtype=np.float64,
- order="F")
- err = gradient_checker.compute_gradient_error(
- x, [2, 5], z[0], [2, 5], x_init_value=x_init)
- print("leaky_relu (float64) gradient of gradient err = ", err)
- self.assertLess(err, 1e-10)
+ with compat.forward_compatibility_horizon(2018, 10, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
def testGradientScalar(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 31b8f3945d..52ea202636 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1601,7 +1601,10 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.to_float(features)
- return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
+ if compat.forward_compatible(2018, 10, 1):
+ return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
+ alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
+ return math_ops.maximum(alpha * features, features, name=name)
def _flatten_outer_dims(logits):
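
Note on the forward compatibility handling in nn_ops.py and relu_op_test.py above: gen_nn_ops.leaky_relu is only emitted once the forward compatibility horizon passes 2018-10-01; before that date, leaky_relu is lowered to maximum(alpha * features, features), which puts no fused LeakyReluGrad op in the graph. The tests therefore wrap the gradient-of-gradient checks in forward_compatibility_horizon(2018, 10, 2) so the fused path, and with it the new C++ registration, is the one exercised. A minimal sketch of that interaction, with an illustrative helper name that is not part of this patch:

from tensorflow.python.compat import compat

def leaky_relu_branch():
    # Mirrors the guard added in nn_ops.py: fused kernel after the
    # horizon, composite maximum(alpha * features, features) before it.
    if compat.forward_compatible(2018, 10, 1):
        return "gen_nn_ops.leaky_relu"
    return "math_ops.maximum"

print(leaky_relu_branch())  # depends on the current compatibility horizon

# The tests temporarily advance the horizon so the fused branch is taken:
with compat.forward_compatibility_horizon(2018, 10, 2):
    print(leaky_relu_branch())  # "gen_nn_ops.leaky_relu"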