/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // See docs in ../ops/nn_ops.cc. #ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_ #define TENSORFLOW_CORE_KERNELS_RELU_OP_H_ #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/relu_op_functor.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { template class ReluOp : public UnaryElementWiseOp> { public: using UnaryElementWiseOp>::UnaryElementWiseOp; void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::Relu functor; functor(context->eigen_device(), input.flat(), output->flat()); } }; // Out of line check to save code space (we have this code once, rather // than once for every NDIMS * NumTypes * Num_different_relu_variants // functions. struct ReluHelpers { static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g, const Tensor& a) { OP_REQUIRES(context, a.IsSameSize(g), errors::InvalidArgument("g and a must be the same size")); } static bool ValidateSameSize(OpKernelContext* context, const Tensor& g, const Tensor& a) { ValidateSameSizeHelper(context, g, a); return context->status().ok(); } }; template class ReluGradOp : public BinaryElementWiseOp> { public: using BinaryElementWiseOp>::BinaryElementWiseOp; void OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output); // INPUTS: // g (gradients): backpropagated gradients // a (inputs): either the inputs that were passed to ReluOp(), or its // outputs (using either one yields the same result here). // OUTPUT: // gradients to backprop template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { OperateNoTemplate(context, g, a, output); } }; template void ReluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::ReluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), output->flat()); } template class Relu6Op : public UnaryElementWiseOp> { public: using UnaryElementWiseOp>::UnaryElementWiseOp; void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::Relu6 functor; functor(context->eigen_device(), input.flat(), output->flat()); } }; template class Relu6GradOp : public BinaryElementWiseOp> { public: using BinaryElementWiseOp>::BinaryElementWiseOp; void OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output); // INPUTS: // g (gradients): backpropagated gradients // a (inputs): inputs that were passed to Relu6Op() // OUTPUT: // gradients to backprop template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { OperateNoTemplate(context, g, a, output); } }; template void Relu6GradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::Relu6Grad functor; functor(context->eigen_device(), g.flat(), a.flat(), output->flat()); } template class LeakyReluOp : public UnaryElementWiseOp> { public: explicit LeakyReluOp(OpKernelConstruction* context) : UnaryElementWiseOp>(context) { float alpha_tmp; OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); alpha_ = T(alpha_tmp); } void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::LeakyRelu functor; functor(context->eigen_device(), input.flat(), alpha_, output->flat()); } private: T alpha_; }; template class LeakyReluGradOp : public BinaryElementWiseOp> { public: explicit LeakyReluGradOp(OpKernelConstruction* context) : BinaryElementWiseOp>(context) { float alpha_tmp; OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); alpha_ = T(alpha_tmp); } void OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, T alpha, Tensor* output); // INPUTS: // g (gradients): backpropagated gradients // a (inputs): either the inputs that were passed to LeakyReluOp(), or its // outputs (using either one yields the same result here). // OUTPUT: // gradients to backprop template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { OperateNoTemplate(context, g, a, alpha_, output); } private: T alpha_; }; template void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, T alpha, Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::LeakyReluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), alpha, output->flat()); }; template class EluOp : public UnaryElementWiseOp> { public: using UnaryElementWiseOp>::UnaryElementWiseOp; void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::Elu functor; functor(context->eigen_device(), input.flat(), output->flat()); } }; template class EluGradOp : public BinaryElementWiseOp> { public: using BinaryElementWiseOp>::BinaryElementWiseOp; void OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output); // INPUTS: // g (gradients): backpropagated gradients // a (outputs): outputs of the EluOp() // OUTPUT: // gradients to backprop template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { OperateNoTemplate(context, g, a, output); } }; template void EluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::EluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), output->flat()); } template class SeluOp : public UnaryElementWiseOp> { public: using UnaryElementWiseOp>::UnaryElementWiseOp; void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { functor::Selu functor; functor(context->eigen_device(), input.flat(), output->flat()); } }; template class SeluGradOp : public BinaryElementWiseOp> { public: using BinaryElementWiseOp>::BinaryElementWiseOp; void OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output); // INPUTS: // g (gradients): backpropagated gradients // a (outputs): outputs of the SeluOp() // OUTPUT: // gradients to backprop template void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { OperateNoTemplate(context, g, a, output); } }; template void SeluGradOp::OperateNoTemplate(OpKernelContext* context, const Tensor& g, const Tensor& a, Tensor* output) { if (!ReluHelpers::ValidateSameSize(context, g, a)) return; functor::SeluGrad functor; functor(context->eigen_device(), g.flat(), a.flat(), output->flat()); } } // namespace tensorflow #undef EIGEN_USE_THREADS #endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_