diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 11:29:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 11:29:04 -0700 |
commit | 96237f7b7ae6b7b8a2cbcf6d64312906b96f060b (patch) | |
tree | a96bb853e59dc37e90e4f8fde229f4d88b3f225a /tensorflow/python | |
parent | 3f0155133d668cf6cee1f1fb362d2a75c04836e3 (diff) | |
parent | 96eec07af06f4dfc75cee57b74ba4b5347619634 (diff) |
Merge pull request #21658 from lowintelligence:master
PiperOrigin-RevId: 216217509
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/relu_op_test.py | 120 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 15 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 3 |
4 files changed, 140 insertions, 0 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6d3ef9a37b..9789dbadee 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1836,6 +1836,8 @@ bool OpGradientDoesntRequireOutputIndices( {"SoftplusGrad", {true, {}}}, {"Softsign", {true, {}}}, {"ReluGrad", {true, {}}}, + {"LeakyRelu", {true, {}}}, + {"LeakyReluGrad", {true, {}}}, {"Conv2D", {true, {}}}, {"DepthwiseConv2dNative", {true, {}}}, {"Dilation2D", {true, {}}}, diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index a45a325b47..672d6556f5 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -282,6 +283,125 @@ class Relu6Test(test.TestCase): self.assertLess(err, 1e-10) +class LeakyReluTest(test.TestCase): + + def _npLeakyRelu(self, np_features, alpha=0.1): + return np.maximum(np_features, alpha * np_features) + + def testNpLeakyRelu(self): + self.assertAllClose( + np.array([[-0.09, 0.7, -0.05, 0.3, -0.01], + [0.1, -0.03, 0.5, -0.07, 0.9]]), + self._npLeakyRelu( + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, + 0.9]]), + alpha=0.1)) + + def _testLeakyRelu(self, np_features, alpha, use_gpu=False): + np_leaky_relu = self._npLeakyRelu(np_features, alpha) + with self.test_session(use_gpu=use_gpu): + leaky_relu = nn_ops.leaky_relu(np_features, alpha) + tf_leaky_relu = leaky_relu.eval() + self.assertAllClose(np_leaky_relu, tf_leaky_relu) + self.assertShapeEqual(np_leaky_relu, leaky_relu) + + def testNumbers(self): + for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.2, + use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.1, + use_gpu=True) + + # The gradient test for Leaky ReLU is a bit tricky as the derivative is not + # well defined at around zero and we want to avoid that in terms of input + # values. + def testGradientFloat32(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradientFloat64(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient err = ", err) + self.assertLess(err, 1e-10) + + def testGradGradFloat32(self): + with compat.forward_compatibility_horizon(2018, 11, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient of gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradGradFloat64(self): + with compat.forward_compatibility_horizon(2018, 11, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient of gradient err = ", err) + self.assertLess(err, 1e-10) + + def testGradientScalar(self): + with self.test_session() as sess: + x = variables.Variable(-100.) + y = nn_ops.leaky_relu(x, 0.05) + loss = y**2 + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2) + train_op = optimizer.minimize(loss) + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + self.assertAllClose(x.eval(), -99.9) + + class EluTest(test.TestCase): def _npElu(self, np_features): diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index e1a01ab4c3..902653befc 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad): array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) +@ops.RegisterGradient("LeakyRelu") +def _LeakyReluGrad(op, grad): + x = op.inputs[0] + alpha = op.get_attr("alpha") + return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha) + + +@ops.RegisterGradient("LeakyReluGrad") +def _LeakyReluGradGrad(op, grad): + x = op.inputs[1] + alpha = op.get_attr("alpha") + return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + + @ops.RegisterGradient("Elu") def _EluGrad(op, grad): return gen_nn_ops.elu_grad(grad, op.outputs[0]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 1fbe31a098..04962da7f7 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -22,6 +22,7 @@ import numbers import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_util @@ -1602,6 +1603,8 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) + if compat.forward_compatible(2018, 11, 1): + return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") return math_ops.maximum(alpha * features, features, name=name) |