aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2016-08-26 17:21:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-26 18:32:27 -0700
commitf2f582b3c00744c5e8857a309d38b374a5bd60fe (patch)
tree476640eca67e44817f129878451c00354c1966fc /tensorflow/core/kernels
parentba98c6b6b8aa38140bd5acdbeb2a9ca419bc6188 (diff)
Optimized the gradients of the sqrt, rsqrt, and inv functions
Change: 131463674
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_inverse.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_rsqrt.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_sqrt.cc7
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h83
7 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
index 2872955589..2d8438f7e0 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY4(inverse, Eigen::half, float, double, int64);
+DEFINE_SIMPLE_BINARY3(inverse_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc
index f1316cbaf0..6a361cfeec 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(rsqrt, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(rsqrt_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc
index 8fba705343..dae93a0766 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(sqrt, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(sqrt_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc
index 855d49e6a2..e7af9b031f 100644
--- a/tensorflow/core/kernels/cwise_op_inverse.cc
+++ b/tensorflow/core/kernels/cwise_op_inverse.cc
@@ -22,4 +22,11 @@ REGISTER5(UnaryOp, CPU, "Inv", functor::inverse, float, Eigen::half, double,
REGISTER4(UnaryOp, GPU, "Inv", functor::inverse, float, Eigen::half, double,
int64);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "InvGrad", functor::inverse_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "InvGrad", functor::inverse_grad, float,
+ Eigen::half, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
index eb66cb3850..3207166e94 100644
--- a/tensorflow/core/kernels/cwise_op_rsqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -21,4 +21,11 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double,
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "RsqrtGrad", functor::rsqrt_grad, float,
+ Eigen::half, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc
index 753dd9d0a0..aecffda4ba 100644
--- a/tensorflow/core/kernels/cwise_op_sqrt.cc
+++ b/tensorflow/core/kernels/cwise_op_sqrt.cc
@@ -21,4 +21,11 @@ REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double,
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "SqrtGrad", functor::sqrt_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "SqrtGrad", functor::sqrt_grad, float,
+ Eigen::half, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 7dc29b67e1..47d5410d0a 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -69,6 +69,79 @@ struct functor_traits<scalar_sigmoid_gradient_op<T>> {
};
};
+// Gradient for the inverse function
+template <typename T>
+struct scalar_inverse_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ const T out_conj = numext::conj(output);
+ return -output_gradient * out_conj * out_conj;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ const Packet out_conj = pconj(output);
+ return pnegate(pmul(output_gradient, pmul(out_conj, out_conj)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_inverse_gradient_op<T>> {
+ enum {
+ Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasMul,
+ };
+};
+
+// Gradient for the sqrt function
+template <typename T>
+struct scalar_sqrt_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return static_cast<T>(0.5) * output_gradient / output;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
+ return pdiv(pmul(const_half, output_gradient), output);
+ }
+};
+template <typename T>
+struct functor_traits<scalar_sqrt_gradient_op<T>> {
+ enum {
+ PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv,
+ Cost =
+ NumTraits<T>::MulCost + NumTraits<T>::template Div<PacketAccess>::Cost,
+ };
+};
+
+// Gradient for the rsqrt function
+template <typename T>
+struct scalar_rsqrt_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return static_cast<T>(-0.5) * (output_gradient * output) *
+ (output * output);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
+ return pmul(const_half,
+ pmul(pmul(output_gradient, output), pmul(output, output)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_rsqrt_gradient_op<T>> {
+ enum {
+ Cost = 4 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasMul,
+ };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -102,6 +175,16 @@ template <typename T>
struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
};
+template <typename T>
+struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> {
+};
+
+template <typename T>
+struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};
+
+template <typename T>
+struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
+
} // end namespace functor
} // end namespace tensorflow