author      TensorFlower Gardener <gardener@tensorflow.org>    2018-10-08 11:29:04 -0700
committer   TensorFlower Gardener <gardener@tensorflow.org>    2018-10-08 11:29:04 -0700
commit  96237f7b7ae6b7b8a2cbcf6d64312906b96f060b (patch)
tree    a96bb853e59dc37e90e4f8fde229f4d88b3f225a /tensorflow/python
parent  3f0155133d668cf6cee1f1fb362d2a75c04836e3 (diff)
parent  96eec07af06f4dfc75cee57b74ba4b5347619634 (diff)
Merge pull request #21658 from lowintelligence:master
PiperOrigin-RevId: 216217509
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc          2
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py   120
-rw-r--r--  tensorflow/python/ops/nn_grad.py                  15
-rw-r--r--  tensorflow/python/ops/nn_ops.py                    3
4 files changed, 140 insertions, 0 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6d3ef9a37b..9789dbadee 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1836,6 +1836,8 @@ bool OpGradientDoesntRequireOutputIndices(
{"SoftplusGrad", {true, {}}},
{"Softsign", {true, {}}},
{"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
{"Conv2D", {true, {}}},
{"DepthwiseConv2dNative", {true, {}}},
{"Dilation2D", {true, {}}},
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index a45a325b47..672d6556f5 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -282,6 +283,125 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-10)
+class LeakyReluTest(test.TestCase):
+
+ def _npLeakyRelu(self, np_features, alpha=0.1):
+ return np.maximum(np_features, alpha * np_features)
+
+ def testNpLeakyRelu(self):
+ self.assertAllClose(
+ np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
+ [0.1, -0.03, 0.5, -0.07, 0.9]]),
+ self._npLeakyRelu(
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]]),
+ alpha=0.1))
+
+ def _testLeakyRelu(self, np_features, alpha, use_gpu=False):
+ np_leaky_relu = self._npLeakyRelu(np_features, alpha)
+ with self.test_session(use_gpu=use_gpu):
+ leaky_relu = nn_ops.leaky_relu(np_features, alpha)
+ tf_leaky_relu = leaky_relu.eval()
+ self.assertAllClose(np_leaky_relu, tf_leaky_relu)
+ self.assertShapeEqual(np_leaky_relu, leaky_relu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.2,
+ use_gpu=False)
+ if t in [np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.1,
+ use_gpu=True)
+
+ # The gradient test for Leaky ReLU is a bit tricky because the derivative
+ # is not well defined around zero, so we keep the gradient-checker inputs
+ # away from zero.
+ def testGradientFloat32(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradientFloat64(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradGradFloat32(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradGradFloat64(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradientScalar(self):
+ with self.test_session() as sess:
+ x = variables.Variable(-100.)
+ y = nn_ops.leaky_relu(x, 0.05)
+ loss = y**2
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2)
+ train_op = optimizer.minimize(loss)
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ self.assertAllClose(x.eval(), -99.9)
+
+
class EluTest(test.TestCase):
def _npElu(self, np_features):
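
The tests above compare the op against a NumPy reference and, as the comment in testGradientFloat32 notes, keep the gradient-checker probe points away from zero, where the slope jumps from alpha to 1. A NumPy reference for the function and its first derivative, for illustration only (none of these helpers exist in the patch):

import numpy as np

def leaky_relu_reference(x, alpha=0.2):
  # Forward pass: x where x > 0, alpha * x elsewhere.
  return np.where(x > 0, x, alpha * x)

def leaky_relu_grad_reference(x, alpha=0.2):
  # Slope: 1 where x > 0, alpha where x < 0; undefined at exactly 0,
  # which is why the gradient-checker inputs avoid zero.
  return np.where(x > 0, 1.0, alpha)

x = np.array([-0.9, -0.1, 0.1, 0.9])
print(leaky_relu_reference(x, alpha=0.1))       # [-0.09 -0.01  0.1   0.9 ]
print(leaky_relu_grad_reference(x, alpha=0.1))  # [ 0.1   0.1   1.    1.  ]
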
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index e1a01ab4c3..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad):
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+@ops.RegisterGradient("LeakyRelu")
+def _LeakyReluGrad(op, grad):
+ x = op.inputs[0]
+ alpha = op.get_attr("alpha")
+ return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)
+
+
+@ops.RegisterGradient("LeakyReluGrad")
+def _LeakyReluGradGrad(op, grad):
+ x = op.inputs[1]
+ alpha = op.get_attr("alpha")
+ return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops.elu_grad(grad, op.outputs[0])
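
In the registrations above, _LeakyReluGrad scales the incoming gradient by the slope of the forward op (1 where x > 0, alpha elsewhere). _LeakyReluGradGrad covers the second-order case: LeakyReluGrad is linear in its gradient input, so its derivative with respect to that input reuses the same rule, while its derivative with respect to the features is zero almost everywhere. A rough NumPy sketch of the elementwise math behind gen_nn_ops.leaky_relu_grad (an illustration, not the kernel itself):

import numpy as np

def leaky_relu_grad(dy, x, alpha=0.2):
  # Backprop rule shared by both registrations: pass dy through where
  # x > 0, scale it by alpha elsewhere.
  return np.where(x > 0, dy, alpha * dy)

def leaky_relu_grad_grad(ddy, x, alpha=0.2):
  # Mirrors _LeakyReluGradGrad: the gradient w.r.t. the incoming dy
  # reuses the first-order rule; the gradient w.r.t. x is zero
  # almost everywhere.
  return leaky_relu_grad(ddy, x, alpha), np.zeros_like(x)
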
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 1fbe31a098..04962da7f7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,6 +22,7 @@ import numbers
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1602,6 +1603,8 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.to_float(features)
+ if compat.forward_compatible(2018, 11, 1):
+ return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features, name=name)
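
Once compat.forward_compatible(2018, 11, 1) returns True, leaky_relu emits the fused LeakyRelu kernel registered above; otherwise it keeps the old lowering math_ops.maximum(alpha * features, features), which agrees with the fused op as long as 0 <= alpha <= 1 (for alpha > 1 the maximum picks the wrong branch on negative inputs). A small NumPy check of that equivalence on the supported alpha range, for illustration only:

import numpy as np

def fused_path(x, alpha):
  # What the fused LeakyRelu kernel computes elementwise.
  return np.where(x > 0, x, alpha * x)

def composed_path(x, alpha):
  # The pre-horizon lowering kept behind the compatibility gate.
  return np.maximum(alpha * x, x)

x = np.linspace(-3.0, 3.0, 13)
for alpha in (0.0, 0.1, 0.2, 1.0):
  assert np.allclose(fused_path(x, alpha), composed_path(x, alpha))
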