author    TensorFlower Gardener <gardener@tensorflow.org>    2018-10-08 11:29:04 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-08 11:29:04 -0700
commit 96237f7b7ae6b7b8a2cbcf6d64312906b96f060b (patch)
tree   a96bb853e59dc37e90e4f8fde229f4d88b3f225a
parent 3f0155133d668cf6cee1f1fb362d2a75c04836e3 (diff)
parent 96eec07af06f4dfc75cee57b74ba4b5347619634 (diff)
Merge pull request #21658 from lowintelligence:master
PiperOrigin-RevId: 216217509
-rw-r--r--  tensorflow/cc/BUILD                                           |   1
-rw-r--r--  tensorflow/cc/gradients/nn_grad.cc                            |  27
-rw-r--r--  tensorflow/cc/gradients/nn_grad_test.cc                       |  27
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt      |   5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt  |  24
-rw-r--r--  tensorflow/core/kernels/relu_op.cc                            | 153
-rw-r--r--  tensorflow/core/kernels/relu_op.h                             |  61
-rw-r--r--  tensorflow/core/kernels/relu_op_functor.h                     |  30
-rw-r--r--  tensorflow/core/kernels/relu_op_gpu.cu.cc                     |  18
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                                 |  15
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                                 |  68
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc                     |   2
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py                | 120
-rw-r--r--  tensorflow/python/ops/nn_grad.py                              |  15
-rw-r--r--  tensorflow/python/ops/nn_ops.py                               |   3
15 files changed, 500 insertions, 69 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index b587e63227..9d2208d84d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -411,6 +411,7 @@ tf_cc_test(
srcs = ["gradients/nn_grad_test.cc"],
deps = [
":cc_ops",
+ ":cc_ops_internal",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 588e96cb19..2a32a2ed6f 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
+Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);
+
+Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);
+
Status EluGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
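
Note: the relationship the two helpers above encode can be sketched in NumPy (a hand-written reference model, not part of the patch):

    import numpy as np

    def leaky_relu_grad(gradients, features, alpha=0.2):
        # Pass the incoming gradient through where the op input was positive,
        # scale it by alpha elsewhere -- the formula LeakyReluGrad computes.
        return np.where(features > 0, gradients, alpha * gradients)

    # leaky_relu_grad is linear in `gradients`, so the gradient of
    # LeakyReluGrad w.r.t. its `gradients` input is leaky_relu_grad(upstream,
    # features) again, and its gradient w.r.t. `features` is zero almost
    # everywhere -- hence the NoGradient() pushed for the second input above.
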
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index aa72cf7ba2..f5a09e09dc 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NNGradTest, LeakyReluGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto y = ops::internal::LeakyRelu(scope_, x);
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ RunTest(x, x_init_value, y, shape);
+}
+
+TEST_F(NNGradTest, LeakyReluGradGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2});
+ Tensor features = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ auto y = ops::internal::LeakyReluGrad(scope_, x, features);
+ RunTest(x, x_init_value, y, shape);
+}
+
TEST_F(NNGradTest, EluGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
new file mode 100644
index 0000000000..280148e032
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "LeakyRelu"
+ visibility: HIDDEN
+ summary: "Computes leaky rectified linear: `max(features, features * alpha)`."
+}
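
Note: a hand-written NumPy illustration of the summary's formula (not part of the patch):

    import numpy as np

    x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    alpha = 0.2
    # The documented formula: identity for positive inputs, alpha-scaled otherwise.
    y = np.maximum(x, alpha * x)
    print(y)  # [-0.4 -0.1  0.   0.5  2. ]
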
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
new file mode 100644
index 0000000000..e427526602
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LeakyReluGrad"
+ visibility: HIDDEN
+ in_arg {
+ name: "gradients"
+ description: <<END
+The backpropagated gradients to the corresponding LeakyRelu operation.
+END
+ }
+ in_arg {
+ name: "features"
+ description: <<END
+The features passed as input to the corresponding LeakyRelu operation,
+OR the outputs of that operation (both work equivalently).
+END
+ }
+ out_arg {
+ name: "backprops"
+ description: <<END
+`gradients * (features > 0) + alpha * gradients * (features <= 0)`.
+END
+ }
+ summary: "Computes leaky rectified linear gradients for a LeakyRelu operation."
+}
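
Note: the output expression above, written out in NumPy (an illustration, not part of the patch):

    import numpy as np

    gradients = np.ones(5)
    features = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    alpha = 0.2
    # Exactly the documented expression for `backprops`.
    backprops = gradients * (features > 0) + alpha * gradients * (features <= 0)
    print(backprops)  # [0.2 0.2 0.2 1.  1. ]
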
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 173fea37ed..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_RELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluGradOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6Op<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6GradOp<CPUDevice, type>)
+#define REGISTER_RELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<CPUDevice, type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
@@ -99,6 +105,19 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
+ void LeakyRelu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct LeakyRelu<GPUDevice, T>; \
+ \
+ template <> \
+ void LeakyReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct LeakyReluGrad<GPUDevice, T>; \
+ \
+ template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
@@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6Op<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6GradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- SeluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
@@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6Op<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6GradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- SeluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
}
};
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+ // Computes LeakyRelu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ T alpha, typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(features * alpha);
+ }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+ // Computes LeakyReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the LeakyRelu op.
+// features: either the inputs that were passed to the LeakyRelu op, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the LeakyRelu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features, T alpha,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+ }
+};
+
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
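
Note: both functors implicitly assume `alpha <= 1`; a NumPy sketch of why the `cwiseMax` form matches the piecewise definition under that assumption (hand-written, not part of the patch):

    import numpy as np

    x = np.array([-2.0, -0.5, 0.5, 2.0])
    alpha = 0.2
    # LeakyRelu's cwiseMax form and LeakyReluGrad's branch-select form agree
    # because alpha <= 1 keeps alpha * x at or below x on the positive side.
    assert np.allclose(np.maximum(x, alpha * x), np.where(x > 0, x, alpha * x))
    # With alpha > 1 the max form would pick alpha * x for positive inputs,
    # so the two definitions diverge; the op's default alpha = 0.2 is safe.
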
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index b9391517c1..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -145,14 +145,16 @@ struct Relu<Device, qint8> {
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Relu<GPUDevice, T>; \
- template struct functor::ReluGrad<GPUDevice, T>; \
- template struct functor::Relu6<GPUDevice, T>; \
- template struct functor::Relu6Grad<GPUDevice, T>; \
- template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>; \
- template struct functor::Selu<GPUDevice, T>; \
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>; \
+ template struct functor::LeakyRelu<GPUDevice, T>; \
+ template struct functor::LeakyReluGrad<GPUDevice, T>; \
+ template struct functor::Elu<GPUDevice, T>; \
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d1d81b27cc..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+REGISTER_OP("LeakyRelu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("LeakyReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 14cc9df9a2..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14296,6 +14296,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6d3ef9a37b..9789dbadee 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1836,6 +1836,8 @@ bool OpGradientDoesntRequireOutputIndices(
{"SoftplusGrad", {true, {}}},
{"Softsign", {true, {}}},
{"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
{"Conv2D", {true, {}}},
{"DepthwiseConv2dNative", {true, {}}},
{"Dilation2D", {true, {}}},
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index a45a325b47..672d6556f5 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -282,6 +283,125 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-10)
+class LeakyReluTest(test.TestCase):
+
+ def _npLeakyRelu(self, np_features, alpha=0.1):
+ return np.maximum(np_features, alpha * np_features)
+
+ def testNpLeakyRelu(self):
+ self.assertAllClose(
+ np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
+ [0.1, -0.03, 0.5, -0.07, 0.9]]),
+ self._npLeakyRelu(
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]]),
+ alpha=0.1))
+
+ def _testLeakyRelu(self, np_features, alpha, use_gpu=False):
+ np_leaky_relu = self._npLeakyRelu(np_features, alpha)
+ with self.test_session(use_gpu=use_gpu):
+ leaky_relu = nn_ops.leaky_relu(np_features, alpha)
+ tf_leaky_relu = leaky_relu.eval()
+ self.assertAllClose(np_leaky_relu, tf_leaky_relu)
+ self.assertShapeEqual(np_leaky_relu, leaky_relu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.2,
+ use_gpu=False)
+ if t in [np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.1,
+ use_gpu=True)
+
+ # The gradient test for Leaky ReLU is a bit tricky: the derivative is not
+ # well defined around zero, so the test inputs deliberately avoid values
+ # near zero.
+ def testGradientFloat32(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradientFloat64(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradGradFloat32(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradGradFloat64(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradientScalar(self):
+ with self.test_session() as sess:
+ x = variables.Variable(-100.)
+ y = nn_ops.leaky_relu(x, 0.05)
+ loss = y**2
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2)
+ train_op = optimizer.minimize(loss)
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ self.assertAllClose(x.eval(), -99.9)
+
+
class EluTest(test.TestCase):
def _npElu(self, np_features):
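
Note: the expected value in `testGradientScalar` follows from a single gradient-descent step; worked by hand:

    # y        = leaky_relu(-100.0, alpha=0.05)      = -5.0
    # loss     = y**2                                = 25.0
    # dloss/dx = 2 * y * alpha = 2 * (-5.0) * 0.05   = -0.5
    # x_new    = x - lr * dloss/dx = -100.0 - 0.2 * (-0.5) = -99.9
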
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index e1a01ab4c3..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad):
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+@ops.RegisterGradient("LeakyRelu")
+def _LeakyReluGrad(op, grad):
+ x = op.inputs[0]
+ alpha = op.get_attr("alpha")
+ return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)
+
+
+@ops.RegisterGradient("LeakyReluGrad")
+def _LeakyReluGradGrad(op, grad):
+ x = op.inputs[1]
+ alpha = op.get_attr("alpha")
+ return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops.elu_grad(grad, op.outputs[0])
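
Note: with both registrations in place, gradients compose twice through `leaky_relu`; a minimal graph-mode sketch (assuming a TF build that includes this change):

    import tensorflow as tf
    from tensorflow.python.ops import gradients_impl
    from tensorflow.python.ops import nn_ops

    x = tf.constant([-1.0, 2.0])
    y = nn_ops.leaky_relu(x, alpha=0.2)
    dy_dx = gradients_impl.gradients(y, x)[0]        # via _LeakyReluGrad: [0.2, 1.0]
    d2y_dx2 = gradients_impl.gradients(dy_dx, x)[0]  # via _LeakyReluGradGrad: zeros,
                                                     # since the slope is piecewise constant
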
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 1fbe31a098..04962da7f7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,6 +22,7 @@ import numbers
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1602,6 +1603,8 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.to_float(features)
+ if compat.forward_compatible(2018, 11, 1):
+ return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features, name=name)
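
Note: either path (the fused kernel once the forward-compatibility window opens, or the `math_ops.maximum` fallback) computes the same values through the public API; a usage sketch:

    import tensorflow as tf

    x = tf.constant([-10.0, -1.0, 0.0, 1.0, 10.0])
    with tf.Session() as sess:
        print(sess.run(tf.nn.leaky_relu(x)))              # default alpha = 0.2
        # [-2.  -0.2  0.   1.  10.]
        print(sess.run(tf.nn.leaky_relu(x, alpha=0.01)))
        # [-0.1  -0.01  0.    1.   10.]
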