From aa25cc078c9b55e5ca3e0f59df43e169bfee8f3c Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Thu, 16 Aug 2018 19:04:37 +0800 Subject: Add LeakyRelu C++ Op and its gradient implementation. LeakyRelu, defined as 'y = { x (x>=0) or alpha*x (x<0) }', was computed by combined Ops 'max(x, alpha*x)' in current codes. Hence its gradient calculation for back propagation would contain a serial of element-wise Ops. This looks really unnecessary for such a simple op and it could be done within just one Op with less memory accesses. --- tensorflow/core/kernels/relu_op.cc | 153 ++++++++++++++++++------------ tensorflow/core/kernels/relu_op.h | 59 ++++++++++++ tensorflow/core/kernels/relu_op_functor.h | 31 ++++++ tensorflow/core/kernels/relu_op_gpu.cu.cc | 18 ++-- tensorflow/core/ops/nn_ops.cc | 15 +++ tensorflow/core/ops/ops.pbtxt | 68 +++++++++++++ 6 files changed, 275 insertions(+), 69 deletions(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index d52358737f..c4f2ef5632 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -#define REGISTER_RELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ - Relu6GradOp) +#define REGISTER_RELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6GradOp) \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + LeakyReluGradOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); #undef REGISTER_RELU_KERNELS @@ -99,6 +105,19 @@ namespace functor { extern template struct Relu6Grad; \ \ template <> \ + void LeakyRelu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + T alpha, typename TTypes::Tensor activations); \ + extern template struct LeakyRelu; \ + \ + template <> \ + void LeakyReluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + T alpha, typename TTypes::Tensor backprops); \ + extern template struct LeakyReluGrad; \ + \ + template <> \ void Elu::operator()(const GPUDevice& d, \ typename TTypes::ConstTensor features, \ typename TTypes::Tensor activations); \ @@ -128,30 +147,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6GradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_GPU).TypeConstraint("T"), \ - EluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - EluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_GPU).TypeConstraint("T"), \ - SeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + LeakyReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_GPU).TypeConstraint("T"), \ + SeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ SeluGradOp) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); @@ -161,30 +186,36 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #ifdef TENSORFLOW_USE_SYCL // Registration of the GPU implementations. -#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_SYCL).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - Relu6GradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - EluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - EluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - SeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + LeakyReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + SeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ SeluGradOp) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index e712b02bd7..c55190065c 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -131,6 +131,65 @@ void Relu6GradOp::OperateNoTemplate(OpKernelContext* context, output->flat()); } +template +class LeakyReluOp : public UnaryElementWiseOp> { + public: + explicit LeakyReluOp(OpKernelConstruction* context) + : UnaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::LeakyRelu functor; + functor(context->eigen_device(), input.flat(), + alpha_, output->flat()); + } + + private: + T alpha_; +}; + +template +class LeakyReluGradOp + : public BinaryElementWiseOp> { + public: + explicit LeakyReluGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, T alpha, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): either the inputs that were passed to LeakyReluOp(), or its + // outputs (using either one yields the same result here). + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, alpha_, output); + } + + private: + T alpha_; +}; + +template +void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::LeakyReluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), alpha, + output->flat()); +}; + template class EluOp : public UnaryElementWiseOp> { public: diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 3bc5ba8a50..7f0951451d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,6 +91,37 @@ struct Relu6Grad { } }; + +// Functor used by LeakyReluOp to do the computations. +template +struct LeakyRelu { + // Computes LeakyRelu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + T alpha, typename TTypes::Tensor activations) { + activations.device(d) = features.cwiseMax(features * alpha); + } +}; + +// Functor used by LeakyReluGradOp to do the computations. +template +struct LeakyReluGrad { + // Computes LeakyReluGrad backprops. + // + // gradients: gradients backpropagated to the LeakyRelu op. + // features: either the inputs that were passed to the LeakyRelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the LeakyRelu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, T alpha, + typename TTypes::Tensor backprops) { + backprops.device(d) = + (features > static_cast(0)).select(gradients, gradients * alpha); + } +}; + // Functor used by EluOp to do the computations. template struct Elu { diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index 089ca8ed27..4452f4dcc9 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -114,14 +114,16 @@ struct ReluGrad { } // namespace functor // Definition of the GPU implementations declared in relu_op.cc. -#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Relu; \ - template struct functor::ReluGrad; \ - template struct functor::Relu6; \ - template struct functor::Relu6Grad; \ - template struct functor::Elu; \ - template struct functor::EluGrad; \ - template struct functor::Selu; \ +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu; \ + template struct functor::ReluGrad; \ + template struct functor::Relu6; \ + template struct functor::Relu6Grad; \ + template struct functor::LeakyRelu; \ + template struct functor::LeakyReluGrad; \ + template struct functor::Elu; \ + template struct functor::EluGrad; \ + template struct functor::Selu; \ template struct functor::SeluGrad; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e0f25fb4ef..023f988f80 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad") .Attr("T: realnumbertype") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("LeakyRelu") + .Input("features: T") + .Output("activations: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("LeakyReluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + REGISTER_OP("Elu") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index f2595279e0..837e91bc23 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13604,6 +13604,74 @@ op { minimum: 1 } } +op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { + name: "LeakykReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} op { name: "LearnedUnigramCandidateSampler" input_arg { -- cgit v1.2.3 From cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Fri, 24 Aug 2018 11:51:34 +0800 Subject: Refine LeakyRelu codes and update APIs. --- .../core/api_def/base_api/api_def_LeakyRelu.pbtxt | 4 ++++ .../api_def/base_api/api_def_LeakyReluGrad.pbtxt | 24 ++++++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 2 +- tensorflow/python/eager/pywrap_tfe_src.cc | 2 +- 4 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt (limited to 'tensorflow/core') diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt new file mode 100644 index 0000000000..4a61889f54 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "LeakyRelu" + summary: "Computes rectified linear: `max(features, features * alpha)`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt new file mode 100644 index 0000000000..e427526602 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LeakyReluGrad" + visibility: HIDDEN + in_arg { + name: "gradients" + description: < 0) + alpha * gradients * (featurs <= 0)`. +END + } + summary: "Computes rectified linear gradients for a LeakyRelu operation." +} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 837e91bc23..7693c2d485 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13637,7 +13637,7 @@ op { } } op { - name: "LeakykReluGrad" + name: "LeakyReluGrad" input_arg { name: "gradients" type_attr: "T" diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 9b3b5fd7aa..18fafd0de1 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1730,6 +1730,7 @@ bool OpDoesntRequireOutput(const string& op_name) { "SoftplusGrad", "Softsign", "ReluGrad", + "LeakyRelu", "LeakyReluGrad", "Conv2D", "DepthwiseConv2dNative", @@ -1800,7 +1801,6 @@ bool OpDoesntRequireInput(const string& op_name) { "BiasAdd", "Relu", "Relu6", - "LeakyRelu", "Elu", "Selu", "SparseSoftmaxCrossEntropyWithLogits", -- cgit v1.2.3 From 2586eb3bfeeef3af357e438ae5aff92d2bac12a5 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Mon, 3 Sep 2018 11:48:35 +0800 Subject: Code fix against ci_build error results. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +- tensorflow/core/kernels/relu_op.cc | 8 ++-- tensorflow/core/kernels/relu_op.h | 8 ++-- tensorflow/core/kernels/relu_op_functor.h | 1 - tensorflow/python/kernel_tests/relu_op_test.py | 50 ++++++++++++------------- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ++ 6 files changed, 39 insertions(+), 35 deletions(-) (limited to 'tensorflow/core') diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index bf0db1f59d..d8c2a1a0fc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -180,8 +180,7 @@ TEST_F(NNGradTest, LeakyReluGradGrad) { // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( - {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, - {5, 2}); + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); Tensor features = test::AsTensor( {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, {5, 2}); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index c4f2ef5632..cafa49cbb6 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -106,15 +106,15 @@ namespace functor { \ template <> \ void LeakyRelu::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor activations); \ + const GPUDevice& d, typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor activations); \ extern template struct LeakyRelu; \ \ template <> \ void LeakyReluGrad::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor backprops); \ + typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor backprops); \ extern template struct LeakyReluGrad; \ \ template <> \ diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index c55190065c..fa79ab03ae 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp> { void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::LeakyRelu functor; - functor(context->eigen_device(), input.flat(), - alpha_, output->flat()); + functor(context->eigen_device(), input.flat(), alpha_, + output->flat()); } private: @@ -183,7 +183,9 @@ class LeakyReluGradOp template void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, - const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + const Tensor& g, + const Tensor& a, T alpha, + Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::LeakyReluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), alpha, diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 7f0951451d..548d5a277d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,7 +91,6 @@ struct Relu6Grad { } }; - // Functor used by LeakyReluOp to do the computations. template struct LeakyRelu { diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 7066f28883..3e24b8a2c4 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -323,37 +323,37 @@ class LeakyReluTest(test.TestCase): def testGradGradFloat32(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - name="x") - y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float32, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float32) gradient of gradient err = ", err) self.assertLess(err, 1e-4) def testGradGradFloat64(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - dtype=dtypes.float64, - name="x") - y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float64, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float64) gradient of gradient err = ", err) self.assertLess(err, 1e-10) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 4de662fe33..9e8d320f06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,6 +1324,10 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "leaky_relu" + argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " + } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3 From a95281ce1b449d8f92a3799ff9c1dbf661b70bc4 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 5 Sep 2018 09:02:40 +0800 Subject: Avoid golden API file changing. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +-- tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt | 1 + tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) (limited to 'tensorflow/core') diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index d8c2a1a0fc..f5a09e09dc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -42,7 +42,6 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; -using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -165,7 +164,7 @@ TEST_F(NNGradTest, Relu6Grad) { TEST_F(NNGradTest, LeakyReluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = LeakyRelu(scope_, x); + auto y = ops::internal::LeakyRelu(scope_, x); // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt index 4a61889f54..280148e032 100644 --- a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -1,4 +1,5 @@ op { graph_op_name: "LeakyRelu" + visibility: HIDDEN summary: "Computes rectified linear: `max(features, features * alpha)`." } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 9e8d320f06..4de662fe33 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,10 +1324,6 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "leaky_relu" - argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " - } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3