diff options
author | 2016-08-26 17:21:44 -0800 | |
---|---|---|
committer | 2016-08-26 18:32:27 -0700 | |
commit | f2f582b3c00744c5e8857a309d38b374a5bd60fe (patch) | |
tree | 476640eca67e44817f129878451c00354c1966fc /tensorflow/core/kernels | |
parent | ba98c6b6b8aa38140bd5acdbeb2a9ca419bc6188 (diff) |
Optimized the gradients of the sqrt, rsqrt, and inv functions
Change: 131463674
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_op_inverse.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_op_rsqrt.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_op_sqrt.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_gradients.h | 83 |
7 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc index 2872955589..2d8438f7e0 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc @@ -16,10 +16,12 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" +#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h" namespace tensorflow { namespace functor { DEFINE_UNARY4(inverse, Eigen::half, float, double, int64); +DEFINE_SIMPLE_BINARY3(inverse_grad, Eigen::half, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc index f1316cbaf0..6a361cfeec 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc @@ -16,10 +16,12 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" +#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h" namespace tensorflow { namespace functor { DEFINE_UNARY3(rsqrt, Eigen::half, float, double); +DEFINE_SIMPLE_BINARY3(rsqrt_grad, Eigen::half, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc index 8fba705343..dae93a0766 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc @@ -16,10 +16,12 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" +#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h" namespace tensorflow { namespace functor { DEFINE_UNARY3(sqrt, Eigen::half, float, double); +DEFINE_SIMPLE_BINARY3(sqrt_grad, Eigen::half, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc index 855d49e6a2..e7af9b031f 100644 --- a/tensorflow/core/kernels/cwise_op_inverse.cc +++ b/tensorflow/core/kernels/cwise_op_inverse.cc @@ -22,4 +22,11 @@ REGISTER5(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double, REGISTER4(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double, int64); #endif + +REGISTER5(SimpleBinaryOp, CPU, "InvGrad", functor::inverse_grad, float, + Eigen::half, double, complex64, complex128); +#if GOOGLE_CUDA +REGISTER3(SimpleBinaryOp, GPU, "InvGrad", functor::inverse_grad, float, + Eigen::half, double); +#endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc index eb66cb3850..3207166e94 100644 --- a/tensorflow/core/kernels/cwise_op_rsqrt.cc +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -21,4 +21,11 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double); #endif + +REGISTER5(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float, + Eigen::half, double, complex64, complex128); +#if GOOGLE_CUDA +REGISTER3(SimpleBinaryOp, GPU, "RsqrtGrad", functor::rsqrt_grad, float, + Eigen::half, double); +#endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc index 753dd9d0a0..aecffda4ba 100644 --- a/tensorflow/core/kernels/cwise_op_sqrt.cc +++ b/tensorflow/core/kernels/cwise_op_sqrt.cc @@ -21,4 +21,11 @@ REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double, #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double); #endif + +REGISTER5(SimpleBinaryOp, CPU, "SqrtGrad", functor::sqrt_grad, float, + Eigen::half, double, complex64, complex128); +#if GOOGLE_CUDA +REGISTER3(SimpleBinaryOp, GPU, "SqrtGrad", functor::sqrt_grad, float, + Eigen::half, double); +#endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h index 7dc29b67e1..47d5410d0a 100644 --- a/tensorflow/core/kernels/cwise_ops_gradients.h +++ b/tensorflow/core/kernels/cwise_ops_gradients.h @@ -69,6 +69,79 @@ struct functor_traits<scalar_sigmoid_gradient_op<T>> { }; }; +// Gradient for the inverse function +template <typename T> +struct scalar_inverse_gradient_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + const T out_conj = numext::conj(output); + return -output_gradient * out_conj * out_conj; + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet out_conj = pconj(output); + return pnegate(pmul(output_gradient, pmul(out_conj, out_conj))); + } +}; +template <typename T> +struct functor_traits<scalar_inverse_gradient_op<T>> { + enum { + Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, + PacketAccess = packet_traits<T>::HasMul, + }; +}; + +// Gradient for the sqrt function +template <typename T> +struct scalar_sqrt_gradient_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + return static_cast<T>(0.5) * output_gradient / output; + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet const_half = pset1<Packet>(static_cast<T>(0.5)); + return pdiv(pmul(const_half, output_gradient), output); + } +}; +template <typename T> +struct functor_traits<scalar_sqrt_gradient_op<T>> { + enum { + PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv, + Cost = + NumTraits<T>::MulCost + NumTraits<T>::template Div<PacketAccess>::Cost, + }; +}; + +// Gradient for the rsqrt function +template <typename T> +struct scalar_rsqrt_gradient_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + return static_cast<T>(-0.5) * (output_gradient * output) * + (output * output); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet const_half = pset1<Packet>(static_cast<T>(-0.5)); + return pmul(const_half, + pmul(pmul(output_gradient, output), pmul(output, output))); + } +}; +template <typename T> +struct functor_traits<scalar_rsqrt_gradient_op<T>> { + enum { + Cost = 4 * NumTraits<T>::MulCost, + PacketAccess = packet_traits<T>::HasMul, + }; +}; + } // end namespace internal } // end namespace Eigen @@ -102,6 +175,16 @@ template <typename T> struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> { }; +template <typename T> +struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> { +}; + +template <typename T> +struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {}; + +template <typename T> +struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {}; + } // end namespace functor } // end namespace tensorflow |