From aa25cc078c9b55e5ca3e0f59df43e169bfee8f3c Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Thu, 16 Aug 2018 19:04:37 +0800 Subject: Add LeakyRelu C++ Op and its gradient implementation. LeakyRelu, defined as 'y = { x (x>=0) or alpha*x (x<0) }', was computed by the combined Ops 'max(x, alpha*x)' in the current code. Hence its gradient calculation for backpropagation would contain a series of element-wise Ops. This is unnecessary for such a simple op; it can be done in a single Op with fewer memory accesses. --- tensorflow/cc/gradients/nn_grad.cc | 13 +++ tensorflow/cc/gradients/nn_grad_test.cc | 13 +++ tensorflow/core/kernels/relu_op.cc | 153 +++++++++++++++---------- tensorflow/core/kernels/relu_op.h | 59 ++++++++++ tensorflow/core/kernels/relu_op_functor.h | 31 +++++ tensorflow/core/kernels/relu_op_gpu.cu.cc | 18 +-- tensorflow/core/ops/nn_ops.cc | 15 +++ tensorflow/core/ops/ops.pbtxt | 68 +++++++++++ tensorflow/python/eager/pywrap_tfe_src.cc | 2 + tensorflow/python/kernel_tests/relu_op_test.py | 113 ++++++++++++++++++ tensorflow/python/ops/nn_grad.py | 15 +++ tensorflow/python/ops/nn_ops.py | 3 +- 12 files changed, 432 insertions(+), 71 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb19..0fc23d0bf7 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,19 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + attrs.Alpha(alpha); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2..5ebece7b6e 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -41,6 +41,7 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; +using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -160,6 +161,18 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero).
+ Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index d52358737f..c4f2ef5632 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -#define REGISTER_RELU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ - Relu6GradOp) +#define REGISTER_RELU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint("T"), \ + Relu6GradOp) \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + LeakyReluGradOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); #undef REGISTER_RELU_KERNELS @@ -99,6 +105,19 @@ namespace functor { extern template struct Relu6Grad; \ \ template <> \ + void LeakyRelu::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ + T alpha, typename TTypes::Tensor activations); \ + extern template struct LeakyRelu; \ + \ + template <> \ + void LeakyReluGrad::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor gradients, \ + typename TTypes::ConstTensor features, \ + T alpha, typename TTypes::Tensor backprops); \ + extern template struct LeakyReluGrad; \ + \ + template <> \ void Elu::operator()(const GPUDevice& d, \ typename TTypes::ConstTensor features, \ typename TTypes::Tensor activations); \ @@ -128,30 +147,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor // Registration of the GPU implementations. 
-#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ - Relu6GradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_GPU).TypeConstraint("T"), \ - EluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - EluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_GPU).TypeConstraint("T"), \ - SeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint("T"), \ + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + LeakyReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + EluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_GPU).TypeConstraint("T"), \ + SeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ SeluGradOp) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); @@ -161,30 +186,36 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #ifdef TENSORFLOW_USE_SYCL // Registration of the GPU implementations. 
-#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - ReluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - ReluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6").Device(DEVICE_SYCL).TypeConstraint("T"), \ - Relu6Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - Relu6GradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Elu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - EluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - EluGradOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Selu").Device(DEVICE_SYCL).TypeConstraint("T"), \ - SeluOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + LeakyReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + LeakyReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Selu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + SeluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ SeluGradOp) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index e712b02bd7..c55190065c 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -131,6 +131,65 @@ void Relu6GradOp::OperateNoTemplate(OpKernelContext* context, output->flat()); } +template +class LeakyReluOp : public UnaryElementWiseOp> { + public: + explicit LeakyReluOp(OpKernelConstruction* context) + : UnaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::LeakyRelu functor; + functor(context->eigen_device(), input.flat(), + alpha_, output->flat()); + } + + private: + T alpha_; +}; + +template +class LeakyReluGradOp + : public BinaryElementWiseOp> { + public: + explicit LeakyReluGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, T alpha, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): either the inputs that were passed to LeakyReluOp(), or its + // outputs (using either one yields the same result here). 
+ // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, alpha_, output); + } + + private: + T alpha_; +}; + +template +void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::LeakyReluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), alpha, + output->flat()); +}; + template class EluOp : public UnaryElementWiseOp> { public: diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 3bc5ba8a50..7f0951451d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,6 +91,37 @@ struct Relu6Grad { } }; + +// Functor used by LeakyReluOp to do the computations. +template +struct LeakyRelu { + // Computes LeakyRelu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + T alpha, typename TTypes::Tensor activations) { + activations.device(d) = features.cwiseMax(features * alpha); + } +}; + +// Functor used by LeakyReluGradOp to do the computations. +template +struct LeakyReluGrad { + // Computes LeakyReluGrad backprops. + // + // gradients: gradients backpropagated to the LeakyRelu op. + // features: either the inputs that were passed to the LeakyRelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the LeakyRelu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, T alpha, + typename TTypes::Tensor backprops) { + backprops.device(d) = + (features > static_cast(0)).select(gradients, gradients * alpha); + } +}; + // Functor used by EluOp to do the computations. template struct Elu { diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index 089ca8ed27..4452f4dcc9 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -114,14 +114,16 @@ struct ReluGrad { } // namespace functor // Definition of the GPU implementations declared in relu_op.cc. 
-#define DEFINE_GPU_KERNELS(T) \ - template struct functor::Relu; \ - template struct functor::ReluGrad; \ - template struct functor::Relu6; \ - template struct functor::Relu6Grad; \ - template struct functor::Elu; \ - template struct functor::EluGrad; \ - template struct functor::Selu; \ +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::Relu; \ + template struct functor::ReluGrad; \ + template struct functor::Relu6; \ + template struct functor::Relu6Grad; \ + template struct functor::LeakyRelu; \ + template struct functor::LeakyReluGrad; \ + template struct functor::Elu; \ + template struct functor::EluGrad; \ + template struct functor::Selu; \ template struct functor::SeluGrad; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e0f25fb4ef..023f988f80 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad") .Attr("T: realnumbertype") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("LeakyRelu") + .Input("features: T") + .Output("activations: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("LeakyReluGrad") + .Input("gradients: T") + .Input("features: T") + .Output("backprops: T") + .Attr("alpha: float = 0.2") + .Attr("T: {half, float, double} = DT_FLOAT") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn); + REGISTER_OP("Elu") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index f2595279e0..837e91bc23 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13604,6 +13604,74 @@ op { minimum: 1 } } +op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { + name: "LeakykReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} op { name: "LearnedUnigramCandidateSampler" input_arg { diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 2d54555cd3..9b3b5fd7aa 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1730,6 +1730,7 @@ bool OpDoesntRequireOutput(const string& op_name) { "SoftplusGrad", "Softsign", "ReluGrad", + "LeakyReluGrad", "Conv2D", "DepthwiseConv2dNative", "Dilation2D", @@ -1799,6 +1800,7 @@ bool OpDoesntRequireInput(const string& op_name) { "BiasAdd", "Relu", "Relu6", + "LeakyRelu", "Elu", "Selu", "SparseSoftmaxCrossEntropyWithLogits", diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 25e947f09e..ccb3a231bb 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -252,6 +252,119 @@ class 
Relu6Test(test.TestCase): self.assertLess(err, 1e-10) +class LeakyReluTest(test.TestCase): + + def _npLeakyRelu(self, np_features, alpha=0.1): + return np.maximum(np_features, alpha * np_features) + + def testNpLeakyRelu(self): + self.assertAllClose( + np.array([[-0.09, 0.7, -0.05, 0.3, -0.01], + [0.1, -0.03, 0.5, -0.07, 0.9]]), + self._npLeakyRelu( + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9] + ]), alpha=0.1)) + + def _testLeakyRelu(self, np_features, alpha, use_gpu=False): + np_leaky_relu = self._npLeakyRelu(np_features, alpha) + with self.test_session(use_gpu=use_gpu): + leaky_relu = nn_ops.leaky_relu(np_features, alpha) + tf_leaky_relu = leaky_relu.eval() + self.assertAllClose(np_leaky_relu, tf_leaky_relu) + self.assertShapeEqual(np_leaky_relu, leaky_relu) + + def testNumbers(self): + for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.2, use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.1, use_gpu=True) + + # The gradient test for ReLU is a bit tricky as the derivative is not well + # defined at around zero and we want to avoid that in terms of input values. + def testGradientFloat32(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradientFloat64(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient err = ", err) + self.assertLess(err, 1e-10) + + def testGradGradFloat32(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient of gradient err = ", err) + self.assertLess(err, 1e-4) + + def testGradGradFloat64(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient of gradient 
err = ", err) + self.assertLess(err, 1e-10) + + def testGradientScalar(self): + with self.test_session() as sess: + x = variables.Variable(-100.) + y = nn_ops.leaky_relu(x, 0.05) + loss = y**2 + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2) + train_op = optimizer.minimize(loss) + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + self.assertAllClose(x.eval(), -99.9) + + class EluTest(test.TestCase): def _npElu(self, np_features): diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index df23ac55ce..c2dd58bdf0 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -390,6 +390,21 @@ def _Relu6GradGrad(op, grad): array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) +@ops.RegisterGradient("LeakyRelu") +def _LeakyReluGrad(op, grad): + x = op.inputs[0] + alpha = op.get_attr("alpha") + return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha) + + +@ops.RegisterGradient("LeakyReluGrad") +def _LeakyReluGradGrad(op, grad): + x = op.inputs[1] + alpha = op.get_attr("alpha") + return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + + @ops.RegisterGradient("Elu") def _EluGrad(op, grad): return gen_nn_ops.elu_grad(grad, op.outputs[0]) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 6fd1273687..31b8f3945d 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,8 +1601,7 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") - return math_ops.maximum(alpha * features, features, name=name) + return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) def _flatten_outer_dims(logits): -- cgit v1.2.3 From cb5c61a3e11a37fb39a246aaf8ed6d02dd9ae9ab Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Fri, 24 Aug 2018 11:51:34 +0800 Subject: Refine LeakyRelu codes and update APIs. --- .../core/api_def/base_api/api_def_LeakyRelu.pbtxt | 4 ++++ .../api_def/base_api/api_def_LeakyReluGrad.pbtxt | 24 ++++++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 2 +- tensorflow/python/eager/pywrap_tfe_src.cc | 2 +- 4 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt new file mode 100644 index 0000000000..4a61889f54 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "LeakyRelu" + summary: "Computes rectified linear: `max(features, features * alpha)`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt new file mode 100644 index 0000000000..e427526602 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "LeakyReluGrad" + visibility: HIDDEN + in_arg { + name: "gradients" + description: < 0) + alpha * gradients * (featurs <= 0)`. +END + } + summary: "Computes rectified linear gradients for a LeakyRelu operation." 
+} diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 837e91bc23..7693c2d485 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13637,7 +13637,7 @@ op { } } op { - name: "LeakykReluGrad" + name: "LeakyReluGrad" input_arg { name: "gradients" type_attr: "T" diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 9b3b5fd7aa..18fafd0de1 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1730,6 +1730,7 @@ bool OpDoesntRequireOutput(const string& op_name) { "SoftplusGrad", "Softsign", "ReluGrad", + "LeakyRelu", "LeakyReluGrad", "Conv2D", "DepthwiseConv2dNative", @@ -1800,7 +1801,6 @@ bool OpDoesntRequireInput(const string& op_name) { "BiasAdd", "Relu", "Relu6", - "LeakyRelu", "Elu", "Selu", "SparseSoftmaxCrossEntropyWithLogits", -- cgit v1.2.3 From 4e72dd865a3fc83baa69f6b7c08720a1b546a464 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 29 Aug 2018 17:05:43 +0800 Subject: Refine LeakyRelu codes. 1. Add C++ gradient of gradient definition of LeakyReLu and revalant UT. 2. Using forward compatibility layer for python code changes. --- tensorflow/cc/gradients/nn_grad.cc | 18 ++++++- tensorflow/cc/gradients/nn_grad_test.cc | 16 ++++++ tensorflow/python/kernel_tests/relu_op_test.py | 70 ++++++++++++++------------ tensorflow/python/ops/nn_ops.py | 5 +- 4 files changed, 73 insertions(+), 36 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0fc23d0bf7..2a32a2ed6f 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -149,13 +149,27 @@ Status LeakyReluGradHelper(const Scope& scope, const Operation& op, float alpha; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); internal::LeakyReluGrad::Attrs attrs; - attrs.Alpha(alpha); - auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), attrs); + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); grad_outputs->push_back(dx); return scope.status(); } REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 5ebece7b6e..bf0db1f59d 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -173,6 +174,21 @@ TEST_F(NNGradTest, LeakyReluGrad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, + {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index ccb3a231bb..7066f28883 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -283,8 +284,9 @@ class LeakyReluTest(test.TestCase): np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), alpha=0.1, use_gpu=True) - # The gradient test for ReLU is a bit tricky as the derivative is not well - # defined at around zero and we want to avoid that in terms of input values. + # The gradient test for Leaky ReLU is a bit tricky as the derivative is not + # well defined at around zero and we want to avoid that in terms of input + # values. 
def testGradientFloat32(self): with self.test_session(): x = constant_op.constant( @@ -319,39 +321,41 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradGradFloat32(self): - with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - name="x") - y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float32, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) - print("leaky_relu (float32) gradient of gradient err = ", err) - self.assertLess(err, 1e-4) + with compat.forward_compatibility_horizon(2018, 10, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float32) gradient of gradient err = ", err) + self.assertLess(err, 1e-4) def testGradGradFloat64(self): - with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - dtype=dtypes.float64, - name="x") - y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float64, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) - print("leaky_relu (float64) gradient of gradient err = ", err) - self.assertLess(err, 1e-10) + with compat.forward_compatibility_horizon(2018, 10, 2): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) + print("leaky_relu (float64) gradient of gradient err = ", err) + self.assertLess(err, 1e-10) def testGradientScalar(self): with self.test_session() as sess: diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 31b8f3945d..52ea202636 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,7 +1601,10 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) + if compat.forward_compatible(2018, 10, 1): + return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) + alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") + return math_ops.maximum(alpha * features, features, name=name) def _flatten_outer_dims(logits): -- cgit v1.2.3 From 2586eb3bfeeef3af357e438ae5aff92d2bac12a5 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Mon, 3 Sep 2018 11:48:35 +0800 Subject: Code fix 
against ci_build error results. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +- tensorflow/core/kernels/relu_op.cc | 8 ++-- tensorflow/core/kernels/relu_op.h | 8 ++-- tensorflow/core/kernels/relu_op_functor.h | 1 - tensorflow/python/kernel_tests/relu_op_test.py | 50 ++++++++++++------------- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ++ 6 files changed, 39 insertions(+), 35 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index bf0db1f59d..d8c2a1a0fc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -180,8 +180,7 @@ TEST_F(NNGradTest, LeakyReluGradGrad) { // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( - {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, - {5, 2}); + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); Tensor features = test::AsTensor( {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, {5, 2}); diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index c4f2ef5632..cafa49cbb6 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -106,15 +106,15 @@ namespace functor { \ template <> \ void LeakyRelu::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor activations); \ + const GPUDevice& d, typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor activations); \ extern template struct LeakyRelu; \ \ template <> \ void LeakyReluGrad::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor gradients, \ - typename TTypes::ConstTensor features, \ - T alpha, typename TTypes::Tensor backprops); \ + typename TTypes::ConstTensor features, T alpha, \ + typename TTypes::Tensor backprops); \ extern template struct LeakyReluGrad; \ \ template <> \ diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index c55190065c..fa79ab03ae 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -143,8 +143,8 @@ class LeakyReluOp : public UnaryElementWiseOp> { void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::LeakyRelu functor; - functor(context->eigen_device(), input.flat(), - alpha_, output->flat()); + functor(context->eigen_device(), input.flat(), alpha_, + output->flat()); } private: @@ -183,7 +183,9 @@ class LeakyReluGradOp template void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, - const Tensor& g, const Tensor& a, T alpha, Tensor* output) { + const Tensor& g, + const Tensor& a, T alpha, + Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::LeakyReluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), alpha, diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 7f0951451d..548d5a277d 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -91,7 +91,6 @@ struct Relu6Grad { } }; - // Functor used by LeakyReluOp to do the computations. 
template struct LeakyRelu { diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 7066f28883..3e24b8a2c4 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -323,37 +323,37 @@ class LeakyReluTest(test.TestCase): def testGradGradFloat32(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - name="x") - y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float32, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float32) gradient of gradient err = ", err) self.assertLess(err, 1e-4) def testGradGradFloat64(self): with compat.forward_compatibility_horizon(2018, 10, 2): with self.test_session(): - x = constant_op.constant( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - shape=[2, 5], - dtype=dtypes.float64, - name="x") - y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") - z = gradients_impl.gradients(y, x) - x_init = np.asarray( - [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], - dtype=np.float64, - order="F") - err = gradient_checker.compute_gradient_error( - x, [2, 5], z[0], [2, 5], x_init_value=x_init) + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + dtype=dtypes.float64, + name="x") + y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu") + z = gradients_impl.gradients(y, x) + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float64, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], z[0], [2, 5], x_init_value=x_init) print("leaky_relu (float64) gradient of gradient err = ", err) self.assertLess(err, 1e-10) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 4de662fe33..9e8d320f06 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,6 +1324,10 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "leaky_relu" + argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " + } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3 From d2ad105d2dff3c79d8f49f5fb8ce74c38f424e74 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Mon, 3 Sep 2018 12:10:51 +0800 Subject: Add XLA support for LeakyReluOp. 
Code contributed by: Meng Chen --- tensorflow/compiler/tests/binary_ops_test.py | 7 +++++ tensorflow/compiler/tests/unary_ops_test.py | 5 ++++ tensorflow/compiler/tf2xla/kernels/relu_op.cc | 42 +++++++++++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0aafda7fb4..8941dd4e27 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -178,6 +178,13 @@ class BinaryOpsTest(xla_test.XLATestCase): [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops._leaky_relu_grad, + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), + expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], dtype=dtype)) + self._testBinary( gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 73adb0d243..91f876fa23 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -361,6 +361,11 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.leaky_relu, + np.array([[-1.0, 1.0]], dtype=dtype), + expected=np.array([[-0.2, 1.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([1, 2, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index d35777ccb1..ec14735884 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -50,6 +50,24 @@ class Relu6Op : public XlaOpKernel { } }; + +class LeakyReluOp : public XlaOpKernel { + public: + explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + // Compute the max of the input x and alpha*x. + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + auto alpha = XlaHelpers::FloatLiteral(builder, input_type(0), + static_cast(alpha_)); + ctx->SetOutput(0, + xla::Max(xla::Mul(alpha, ctx->Input(0)), ctx->Input(0))); + } + private: + float alpha_; +}; + class ReluGradOp : public XlaOpKernel { public: explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -84,10 +102,34 @@ class Relu6GradOp : public XlaOpKernel { } }; +class LeakyReluGradOp : public XlaOpKernel { + public: + explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return the alpha * lhs. 
+ void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = xla::Gt(ctx->Input(1), zero); + auto alpha = XlaHelpers::FloatLiteral(b, input_type(0), + static_cast(alpha_)); + ctx->SetOutput(0, + xla::Select(pred, ctx->Input(0), xla::Mul(alpha, ctx->Input(0)))); + } + private: + float alpha_; +}; + REGISTER_XLA_OP(Name("Relu"), ReluOp); REGISTER_XLA_OP(Name("Relu6"), Relu6Op); +REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp); REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp); +REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp); } // namespace } // namespace tensorflow -- cgit v1.2.3 From a95281ce1b449d8f92a3799ff9c1dbf661b70bc4 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 5 Sep 2018 09:02:40 +0800 Subject: Avoid golden API file changing. --- tensorflow/cc/gradients/nn_grad_test.cc | 3 +-- tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt | 1 + tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index d8c2a1a0fc..f5a09e09dc 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -42,7 +42,6 @@ using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; -using ops::LeakyRelu; using ops::Selu; using ops::Softmax; using ops::Softplus; @@ -165,7 +164,7 @@ TEST_F(NNGradTest, Relu6Grad) { TEST_F(NNGradTest, LeakyReluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = LeakyRelu(scope_, x); + auto y = ops::internal::LeakyRelu(scope_, x); // Avoid input values where Leaky ReLU gradient is not well defined (around // zero). Tensor x_init_value = test::AsTensor( diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt index 4a61889f54..280148e032 100644 --- a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt @@ -1,4 +1,5 @@ op { graph_op_name: "LeakyRelu" + visibility: HIDDEN summary: "Computes rectified linear: `max(features, features * alpha)`." } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 9e8d320f06..4de662fe33 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1324,10 +1324,6 @@ tf_module { name: "lbeta" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "leaky_relu" - argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " - } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- cgit v1.2.3 From 90cf7fb7786c8a9c135ef73482856b082e80f61a Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Tue, 11 Sep 2018 12:48:30 +0800 Subject: Fix lint errors and typos. 
--- tensorflow/compiler/tests/binary_ops_test.py | 9 +++++---- tensorflow/compiler/tf2xla/kernels/relu_op.cc | 14 +++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 8941dd4e27..069e83d083 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -179,11 +179,12 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._leaky_relu_grad, + gen_nn_ops.leaky_relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), - np.array( - [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), - expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], dtype=dtype)) + np.array([-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + dtype=dtype), + expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], + dtype=dtype)) self._testBinary( gen_nn_ops.softmax_cross_entropy_with_logits, diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index ec14735884..8d65e0339c 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -50,7 +50,6 @@ class Relu6Op : public XlaOpKernel { } }; - class LeakyReluOp : public XlaOpKernel { public: explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -61,9 +60,9 @@ class LeakyReluOp : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto alpha = XlaHelpers::FloatLiteral(builder, input_type(0), static_cast(alpha_)); - ctx->SetOutput(0, - xla::Max(xla::Mul(alpha, ctx->Input(0)), ctx->Input(0))); + ctx->SetOutput(0, xla::Max(xla::Mul(alpha, ctx->Input(0)), ctx->Input(0))); } + private: float alpha_; }; @@ -115,11 +114,12 @@ class LeakyReluGradOp : public XlaOpKernel { const auto zero = xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); const auto pred = xla::Gt(ctx->Input(1), zero); - auto alpha = XlaHelpers::FloatLiteral(b, input_type(0), - static_cast(alpha_)); - ctx->SetOutput(0, - xla::Select(pred, ctx->Input(0), xla::Mul(alpha, ctx->Input(0)))); + auto alpha = + XlaHelpers::FloatLiteral(b, input_type(0), static_cast(alpha_)); + ctx->SetOutput( + 0, xla::Select(pred, ctx->Input(0), xla::Mul(alpha, ctx->Input(0)))); } + private: float alpha_; }; -- cgit v1.2.3 From 5e9a9547f907599f6954fc5e28b7a78acf3b54eb Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 12 Sep 2018 11:02:12 +0800 Subject: Revert "Add XLA support for LeakyReluOp." This reverts commit d2ad105d2dff3c79d8f49f5fb8ce74c38f424e74. Since bfloat16 was not supported by LeakyRelu, but it should be supported in XLA Ops. 
--- tensorflow/compiler/tests/binary_ops_test.py | 8 ----- tensorflow/compiler/tests/unary_ops_test.py | 5 ---- tensorflow/compiler/tf2xla/kernels/relu_op.cc | 42 --------------------------- 3 files changed, 55 deletions(-) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index c478ff4eea..17280e445b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -178,14 +178,6 @@ class BinaryOpsTest(xla_test.XLATestCase): [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) - self._testBinary( - gen_nn_ops.leaky_relu_grad, - np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), - np.array([-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], - dtype=dtype), - expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], - dtype=dtype)) - self._testBinary( gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index dd29ef34ce..5b0e57f83f 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -361,11 +361,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) - self._assertOpOutputMatchesExpected( - nn_ops.leaky_relu, - np.array([[-1.0, 1.0]], dtype=dtype), - expected=np.array([[-0.2, 1.0]], dtype=dtype)) - self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([1, 2, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 8d65e0339c..d35777ccb1 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -50,23 +50,6 @@ class Relu6Op : public XlaOpKernel { } }; -class LeakyReluOp : public XlaOpKernel { - public: - explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); - } - // Compute the max of the input x and alpha*x. - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - auto alpha = XlaHelpers::FloatLiteral(builder, input_type(0), - static_cast(alpha_)); - ctx->SetOutput(0, xla::Max(xla::Mul(alpha, ctx->Input(0)), ctx->Input(0))); - } - - private: - float alpha_; -}; - class ReluGradOp : public XlaOpKernel { public: explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -101,35 +84,10 @@ class Relu6GradOp : public XlaOpKernel { } }; -class LeakyReluGradOp : public XlaOpKernel { - public: - explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); - } - // Return the lhs (incoming gradient) if the rhs (input feature) > 0, - // otherwise return the alpha * lhs. 
- void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - const TensorShape shape = ctx->InputShape(0); - const auto zero = - xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto pred = xla::Gt(ctx->Input(1), zero); - auto alpha = - XlaHelpers::FloatLiteral(b, input_type(0), static_cast(alpha_)); - ctx->SetOutput( - 0, xla::Select(pred, ctx->Input(0), xla::Mul(alpha, ctx->Input(0)))); - } - - private: - float alpha_; -}; - REGISTER_XLA_OP(Name("Relu"), ReluOp); REGISTER_XLA_OP(Name("Relu6"), Relu6Op); -REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp); REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp); -REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp); } // namespace } // namespace tensorflow -- cgit v1.2.3 From f0886f7269de900d226455d4831722f6fc94a71b Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Tue, 25 Sep 2018 09:59:17 +0800 Subject: Fix build dependencies in tensorflow/cc/BUILD. --- tensorflow/cc/BUILD | 1 + tensorflow/python/kernel_tests/relu_op_test.py | 4 ++-- tensorflow/python/ops/nn_ops.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f56521dac0..e99d15f85d 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -410,6 +410,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index 86d9c90e83..d97a1613b9 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -351,7 +351,7 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradGradFloat32(self): - with compat.forward_compatibility_horizon(2018, 10, 2): + with compat.forward_compatibility_horizon(2018, 11, 2): with self.test_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], @@ -369,7 +369,7 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-4) def testGradGradFloat64(self): - with compat.forward_compatibility_horizon(2018, 10, 2): + with compat.forward_compatibility_horizon(2018, 11, 2): with self.test_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index d646245ce3..2861f40586 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1601,7 +1601,7 @@ def leaky_relu(features, alpha=0.2, name=None): features = ops.convert_to_tensor(features, name="features") if features.dtype.is_integer: features = math_ops.to_float(features) - if compat.forward_compatible(2018, 10, 1): + if compat.forward_compatible(2018, 11, 1): return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") return math_ops.maximum(alpha * features, features, name=name) -- cgit v1.2.3 From 96eec07af06f4dfc75cee57b74ba4b5347619634 Mon Sep 17 00:00:00 2001 From: Cao Zongyan Date: Wed, 26 Sep 2018 13:04:46 +0800 Subject: Re-add compat module for leaky_relu implementation. 
--- tensorflow/python/ops/nn_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 3f64f0af9a..78e000e458 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -22,6 +22,7 @@ import numbers import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_util -- cgit v1.2.3
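
For reference, the semantics implemented by the fused kernels in this series can be sketched in plain NumPy. This is a minimal sketch only: the function names below are illustrative and not part of the patch, and the default alpha of 0.2 is taken from the LeakyRelu op registration. It mirrors the LeakyRelu functor (activations = features.cwiseMax(features * alpha)) and the LeakyReluGrad functor ((features > 0).select(gradients, gradients * alpha)) shown above.

import numpy as np

def leaky_relu(x, alpha=0.2):
    # Forward pass: max(x, alpha * x), as in the LeakyRelu functor.
    return np.maximum(x, alpha * x)

def leaky_relu_grad(gradients, features, alpha=0.2):
    # Backward pass: pass the incoming gradient through where the input
    # feature is positive, scale it by alpha elsewhere, as in LeakyReluGrad.
    return np.where(features > 0, gradients, alpha * gradients)

x = np.array([-0.9, -0.1, 0.1, 0.9], dtype=np.float32)
print(leaky_relu(x))                        # approx. [-0.18 -0.02  0.1  0.9]
print(leaky_relu_grad(np.ones_like(x), x))  # approx. [ 0.2   0.2   1.   1. ]

This single backward kernel stands in for the chain of element-wise ops that differentiating the previous Python implementation, math_ops.maximum(alpha * features, features), would otherwise produce, which is the motivation stated in the first commit message.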