Diffstat (limited to 'tensorflow/core/kernels/training_ops_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/training_ops_gpu.cu.cc  127
1 file changed, 127 insertions, 0 deletions
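Every functor in the new file applies a scalar hyperparameter (learning rate, momentum, etc.) to a flat tensor with the same Eigen idiom: the rank-0 scalar is reshaped to a length-1 vector and broadcast to the gradient's length before the element-wise arithmetic. A minimal, self-contained sketch of that idiom on plain Eigen tensors (the values and the use of the default CPU evaluator are illustrative only, not part of this change):

  #include <unsupported/Eigen/CXX11/Tensor>
  #include <iostream>

  int main() {
    // Flat parameter and gradient tensors of the same length.
    Eigen::Tensor<float, 1> var(4), grad(4);
    var.setConstant(1.0f);
    grad.setValues({0.1f, 0.2f, 0.3f, 0.4f});

    // Rank-0 tensor standing in for TTypes<T>::ConstScalar (e.g. the learning rate).
    Eigen::Tensor<float, 0> lr;
    lr.setConstant(0.5f);

    // The reshape-then-broadcast pattern used throughout the diff below.
    Eigen::array<Eigen::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var -= lr.reshape(single).broadcast(bcast) * grad;  // var = var - lr * grad

    std::cout << var << std::endl;  // 0.95 0.9 0.85 0.8
    return 0;
  }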
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
new file mode 100644
index 0000000000..3106f29648
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -0,0 +1,127 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/training_ops.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T>
+struct ApplyGradientDescent<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::ConstScalar alpha,
+ typename TTypes<T>::ConstFlat delta) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = delta.dimension(0);
+ Eigen::Sizes<1> single;
+ var.device(d) -= alpha.reshape(single).broadcast(bcast) * delta;
+ }
+};
+
+template <typename T>
+struct ApplyAdagrad<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad) {
+ accum.device(d) += grad.square();
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+ }
+};
+
+template <typename T>
+struct ApplyMomentum<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad,
+ typename TTypes<T>::ConstScalar momentum) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
+ var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+ }
+};
+
+template <typename T>
+struct ApplyAdam<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar beta2_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ const auto one = static_cast<T>(1.0);
+ m.device(d) =
+ m +
+ (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+ (grad - m);
+ v.device(d) =
+ v +
+ (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
+ (grad.square() - v);
+ var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
+ (beta1_power.constant(one) - beta1_power))
+ .reshape(single)
+ .broadcast(bcast) *
+ m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt());
+ }
+};
+
+template <typename T>
+struct ApplyRMSProp<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar rho,
+ typename TTypes<T>::ConstScalar momentum,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ const auto one = static_cast<T>(1.0);
+ ms.device(d) = ms +
+ (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
+ (grad.square() - ms);
+ mom.device(d) =
+ mom * momentum.reshape(single).broadcast(bcast) +
+ lr.reshape(single).broadcast(bcast) * grad /
+ ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
+ var.device(d) -= mom;
+ }
+};
+
+} // namespace functor
+
+template struct functor::ApplyGradientDescent<GPUDevice, float>;
+template struct functor::ApplyGradientDescent<GPUDevice, double>;
+
+template struct functor::ApplyAdagrad<GPUDevice, float>;
+template struct functor::ApplyAdagrad<GPUDevice, double>;
+
+template struct functor::ApplyMomentum<GPUDevice, float>;
+template struct functor::ApplyMomentum<GPUDevice, double>;
+
+template struct functor::ApplyAdam<GPUDevice, float>;
+template struct functor::ApplyAdam<GPUDevice, double>;
+
+template struct functor::ApplyRMSProp<GPUDevice, float>;
+template struct functor::ApplyRMSProp<GPUDevice, double>;
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
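These GPU specializations are picked up by the templated training op kernels through the functor layer declared in training_ops.h (included above). A rough sketch of how one of them could be driven from code that already holds GPU-resident flat tensors (the wrapper function below is a hypothetical illustration; the registering op kernels live elsewhere and are not part of this diff):

  #define EIGEN_USE_GPU
  #include "tensorflow/core/framework/tensor_types.h"
  #include "tensorflow/core/kernels/training_ops.h"

  namespace tensorflow {

  // Hypothetical helper: applies one SGD step, var -= alpha * delta,
  // by delegating to the GPU functor instantiated in this file.
  template <typename T>
  void RunGradientDescentStep(const Eigen::GpuDevice& d,
                              typename TTypes<T>::Flat var,
                              typename TTypes<T>::ConstScalar alpha,
                              typename TTypes<T>::ConstFlat delta) {
    functor::ApplyGradientDescent<Eigen::GpuDevice, T>()(d, var, alpha, delta);
  }

  }  // namespace tensorflow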