diff options
-rw-r--r-- | tensorflow/cc/gradients/nn_grad.cc | 13 | ||||
-rw-r--r-- | tensorflow/cc/gradients/nn_grad_test.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op.cc | 153 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op.h | 59 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op_functor.h | 31 | ||||
-rw-r--r-- | tensorflow/core/kernels/relu_op_gpu.cu.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/ops/ops.pbtxt | 68 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/relu_op_test.py | 113 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 15 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 3 |
12 files changed, 432 insertions, 71 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb19..0fc23d0bf7 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,19 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + attrs.Alpha(alpha); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2..5ebece7b6e 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -41,6 +41,7 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; +using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -160,6 +161,18 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). 
+ Tensor x_init_value = test::AsTensor<float>( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index d52358737f..c4f2ef5632 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -#define REGISTER_RELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - ReluOp<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - ReluGradOp<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - Relu6Op<CPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - Relu6GradOp<CPUDevice, type>) +#define REGISTER_RELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ReluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ReluGradOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + Relu6Op<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + Relu6GradOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + LeakyReluOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + LeakyReluGradOp<CPUDevice, type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); #undef REGISTER_RELU_KERNELS @@ -99,6 +105,19 @@ namespace functor { extern template struct Relu6Grad<GPUDevice, T>; \ \ template <> \ + void LeakyRelu<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor features, \ + T alpha, typename TTypes<T>::Tensor activations); \ + extern template struct LeakyRelu<GPUDevice, T>; \ + \ + template <> \ + void LeakyReluGrad<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \ + typename TTypes<T>::ConstTensor features, \ + T alpha, typename TTypes<T>::Tensor backprops); \ + extern template struct LeakyReluGrad<GPUDevice, T>; \ + \ + template <> \ void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \ typename TTypes<T>::ConstTensor features, \ typename TTypes<T>::Tensor activations); \ @@ -128,30 +147,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - ReluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - ReluGradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - Relu6Op<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - Relu6GradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - EluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - EluGradOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - SeluOp<GPUDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - 
Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + ReluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + ReluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + Relu6Op<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + Relu6GradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + LeakyReluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + LeakyReluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + EluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + EluGradOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SeluOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ SeluGradOp<GPUDevice, type>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); @@ -161,30 +186,36 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #ifdef TENSORFLOW_USE_SYCL // Registration of the GPU implementations. 
-#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - ReluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - ReluGradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - Relu6Op<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - Relu6GradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - EluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - EluGradOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ - SeluOp<SYCLDevice, type>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + ReluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + ReluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + Relu6Op<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + Relu6GradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + LeakyReluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + LeakyReluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + EluOp<SYCLDevice, type>); \ + 
REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + EluGradOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + SeluOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ SeluGradOp<SYCLDevice, type>) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index e712b02bd7..c55190065c 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -132,6 +132,65 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, } template <typename Device, typename T> +class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> { + public: + explicit LeakyReluOp(OpKernelConstruction* context) + : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::LeakyRelu<Device, T> functor; + functor(context->eigen_device<Device>(), input.flat<T>(), + alpha_, output->flat<T>()); + } + + private: + T alpha_; +}; + +template <typename Device, typename T> +class LeakyReluGradOp + : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> { + public: + explicit LeakyReluGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, T alpha, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): either the inputs that were passed to LeakyReluOp(), or its + // outputs (using either one yields 
the same result here). + // OUTPUT: + // gradients to backprop + template <int NDIMS> + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, alpha_, output); + } + + private: + T alpha_; +}; + +template <typename Device, typename T> +void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::LeakyReluGrad<Device, T> functor; + functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha, + output->flat<T>()); +} + +template <typename Device, typename T> class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> { public: using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp; diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 3bc5ba8a50..7f0951451d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,6 +91,37 @@ struct Relu6Grad { } }; + +// Functor used by LeakyReluOp to do the computations. +template <typename Device, typename T> +struct LeakyRelu { + // Computes LeakyRelu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes<T>::ConstTensor features, + T alpha, typename TTypes<T>::Tensor activations) { + activations.device(d) = features.cwiseMax(features * alpha); + } +}; + +// Functor used by LeakyReluGradOp to do the computations. +template <typename Device, typename T> +struct LeakyReluGrad { + // Computes LeakyReluGrad backprops. + // + // gradients: gradients backpropagated to the LeakyRelu op. + // features: either the inputs that were passed to the LeakyRelu op, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the LeakyRelu inputs. 
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients, + typename TTypes<T>::ConstTensor features, T alpha, + typename TTypes<T>::Tensor backprops) { + backprops.device(d) = + (features > static_cast<T>(0)).select(gradients, gradients * alpha); + } +}; + // Functor used by EluOp to do the computations. template <typename Device, typename T> struct Elu { diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index 089ca8ed27..4452f4dcc9 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -114,14 +114,16 @@ struct ReluGrad<Device, Eigen::half> { } // namespace functor // Definition of the GPU implementations declared in relu_op.cc. -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Relu<GPUDevice, T>; \ - template struct functor::ReluGrad<GPUDevice, T>; \ - template struct functor::Relu6<GPUDevice, T>; \ - template struct functor::Relu6Grad<GPUDevice, T>; \ - template struct functor::Elu<GPUDevice, T>; \ - template struct functor::EluGrad<GPUDevice, T>; \ - template struct functor::Selu<GPUDevice, T>; \ +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu<GPUDevice, T>; \ + template struct functor::ReluGrad<GPUDevice, T>; \ + template struct functor::Relu6<GPUDevice, T>; \ + template struct functor::Relu6Grad<GPUDevice, T>; \ + template struct functor::LeakyRelu<GPUDevice, T>; \ + template struct functor::LeakyReluGrad<GPUDevice, T>; \ + template struct functor::Elu<GPUDevice, T>; \ + template struct functor::EluGrad<GPUDevice, T>; \ + template struct functor::Selu<GPUDevice, T>; \ template struct functor::SeluGrad<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e0f25fb4ef..023f988f80 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad") .Attr("T: realnumbertype") 
.SetShapeFn(shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("LeakyRelu") + .Input("features: T") + .Output("activations: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("LeakyReluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + REGISTER_OP("Elu") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index f2595279e0..837e91bc23 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13605,6 +13605,74 @@ op { } } op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} op { + name: "LeakyReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "LearnedUnigramCandidateSampler" input_arg { name: "true_classes" diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 2d54555cd3..9b3b5fd7aa 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1730,6 +1730,7 @@ bool OpDoesntRequireOutput(const string& op_name) { 
"SoftplusGrad", "Softsign", "ReluGrad", + "LeakyReluGrad", "Conv2D", "DepthwiseConv2dNative", "Dilation2D", @@ -1799,6 +1800,7 @@ bool OpDoesntRequireInput(const string& op_name) { "BiasAdd", "Relu", "Relu6", + "LeakyRelu", "Elu", "Selu", "SparseSoftmaxCrossEntropyWithLogits", diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 25e947f09e..ccb3a231bb 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -252,6 +252,119 @@ class Relu6Test(test.TestCase): self.assertLess(err, 1e-10) +class LeakyReluTest(test.TestCase): + + def _npLeakyRelu(self, np_features, alpha=0.1): + return np.maximum(np_features, alpha * np_features) + + def testNpLeakyRelu(self): + self.assertAllClose( + np.array([[-0.09, 0.7, -0.05, 0.3, -0.01], + [0.1, -0.03, 0.5, -0.07, 0.9]]), + self._npLeakyRelu( + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9] + ]), alpha=0.1)) + + def _testLeakyRelu(self, np_features, alpha, use_gpu=False): + np_leaky_relu = self._npLeakyRelu(np_features, alpha) + with self.test_session(use_gpu=use_gpu): + leaky_relu = nn_ops.leaky_relu(np_features, alpha) + tf_leaky_relu = leaky_relu.eval() + self.assertAllClose(np_leaky_relu, tf_leaky_relu) + self.assertShapeEqual(np_leaky_relu, leaky_relu) + + def testNumbers(self): + for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.2, use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.1, use_gpu=True) + + # The gradient test for ReLU is a bit tricky as the derivative is not well + # defined at around zero and we want to avoid that in terms of input values. 
+ def testGradientFloat32(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradientFloat64(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient err = ", err) + self.assertLess(err, 1e-10) + + def testGradGradFloat32(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient of gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradGradFloat64(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, 
-0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient of gradient err = ", err) + self.assertLess(err, 1e-10) + + def testGradientScalar(self): + with self.test_session() as sess: + x = variables.Variable(-100.) + y = nn_ops.leaky_relu(x, 0.05) + loss = y**2 + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2) + train_op = optimizer.minimize(loss) + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + self.assertAllClose(x.eval(), -99.9) + + class EluTest(test.TestCase): def _npElu(self, np_features): diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index df23ac55ce..c2dd58bdf0 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -390,6 +390,21 @@ def _Relu6GradGrad(op, grad): array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) +@ops.RegisterGradient("LeakyRelu") +def _LeakyReluGrad(op, grad): + x = op.inputs[0] + alpha = op.get_attr("alpha") + return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha) + + +@ops.RegisterGradient("LeakyReluGrad") +def _LeakyReluGradGrad(op, grad): + x = op.inputs[1] + alpha = op.get_attr("alpha") + return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + + @ops.RegisterGradient("Elu") def _EluGrad(op, grad): return gen_nn_ops.elu_grad(grad, op.outputs[0]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 6fd1273687..31b8f3945d 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,8 +1601,7 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - alpha = ops.convert_to_tensor(alpha, 
dtype=features.dtype, name="alpha") - return math_ops.maximum(alpha * features, features, name=name) + return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) def _flatten_outer_dims(logits): |