#ifndef TENSORFLOW_KERNELS_RELU_OP_H_ #define TENSORFLOW_KERNELS_RELU_OP_H_ // Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc. #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace functor { // Functor used by ReluOp to do the computations. template struct Relu { // Computes Relu activation. // // features: any shape. // activations: same shape as "features". void operator()(const Device& d, typename TTypes::ConstTensor features, typename TTypes::Tensor activations) { activations.device(d) = features.cwiseMax(static_cast(0)); } }; // Functor used by ReluGradOp to do the computations. template struct ReluGrad { // Computes ReluGrad backprops. // // gradients: gradients backpropagated to the Relu op. // features: inputs that where passed to the Relu op. // backprops: gradients to backpropagate to the Relu inputs. void operator()(const Device& d, typename TTypes::ConstTensor gradients, typename TTypes::ConstTensor features, typename TTypes::Tensor backprops) { // NOTE: When the activation is exactly zero, we arbitrarily choose to not // propagate the associated gradient value. backprops.device(d) = gradients * (features > features.constant(static_cast(0))); } }; // Functor used by Relu6Op to do the computations. template struct Relu6 { // Computes Relu6 activation. // // features: any shape. // activations: same shape as "features". void operator()(const Device& d, typename TTypes::ConstTensor features, typename TTypes::Tensor activations) { activations.device(d) = features.cwiseMax(static_cast(0)).cwiseMin(static_cast(6)); } }; // Functor used by ReluGradOp to do the computations. template struct Relu6Grad { // Computes Relu6Grad backprops. // // gradients: gradients backpropagated to the Relu6 op. // features: inputs that where passed to the Relu6 op. // backprops: gradients to backpropagate to the Relu6 inputs. void operator()(const Device& d, typename TTypes::ConstTensor gradients, typename TTypes::ConstTensor features, typename TTypes::Tensor backprops) { // NOTE: When the activation is exactly zero or six, we // arbitrarily choose to not propagate the associated gradient // value. backprops.device(d) = gradients * (features > features.constant(static_cast(0))) * (features < features.constant(static_cast(6))); } }; } // namespace functor } // namespace tensorflow #endif // TENSORFLOW_KERNELS_RELU_OP_H_