Diffstat (limited to 'tensorflow/core/kernels/training_ops_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops_gpu.cu.cc | 127 |
1 file changed, 127 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
new file mode 100644
index 0000000000..3106f29648
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -0,0 +1,127 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/training_ops.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T>
+struct ApplyGradientDescent<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::ConstScalar alpha,
+                  typename TTypes<T>::ConstFlat delta) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = delta.dimension(0);
+    Eigen::Sizes<1> single;
+    var.device(d) -= alpha.reshape(single).broadcast(bcast) * delta;
+  }
+};
+
+template <typename T>
+struct ApplyAdagrad<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad) {
+    accum.device(d) += grad.square();
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+  }
+};
+
+template <typename T>
+struct ApplyMomentum<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstFlat grad,
+                  typename TTypes<T>::ConstScalar momentum) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
+    var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+  }
+};
+
+template <typename T>
+struct ApplyAdam<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+                  typename TTypes<T>::ConstScalar beta1_power,
+                  typename TTypes<T>::ConstScalar beta2_power,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar beta1,
+                  typename TTypes<T>::ConstScalar beta2,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    const auto one = static_cast<T>(1.0);
+    m.device(d) =
+        m +
+        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+            (grad - m);
+    v.device(d) =
+        v +
+        (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
+            (grad.square() - v);
+    var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
+                      (beta1_power.constant(one) - beta1_power))
+                         .reshape(single)
+                         .broadcast(bcast) *
+                     m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt());
+  }
+};
+
+template <typename T>
+struct ApplyRMSProp<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar rho,
+                  typename TTypes<T>::ConstScalar momentum,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad) {
+    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+    bcast[0] = grad.dimension(0);
+    Eigen::Sizes<1> single;
+    const auto one = static_cast<T>(1.0);
+    ms.device(d) = ms +
+                   (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
+                       (grad.square() - ms);
+    mom.device(d) =
+        mom * momentum.reshape(single).broadcast(bcast) +
+        lr.reshape(single).broadcast(bcast) * grad /
+            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
+    var.device(d) -= mom;
+  }
+};
+
+}  // namespace functor
+
+template struct functor::ApplyGradientDescent<GPUDevice, float>;
+template struct functor::ApplyGradientDescent<GPUDevice, double>;
+
+template struct functor::ApplyAdagrad<GPUDevice, float>;
+template struct functor::ApplyAdagrad<GPUDevice, double>;
+
+template struct functor::ApplyMomentum<GPUDevice, float>;
+template struct functor::ApplyMomentum<GPUDevice, double>;
+
+template struct functor::ApplyAdam<GPUDevice, float>;
+template struct functor::ApplyAdam<GPUDevice, double>;
+
+template struct functor::ApplyRMSProp<GPUDevice, float>;
+template struct functor::ApplyRMSProp<GPUDevice, double>;
+}  // end namespace tensorflow
+
+#endif  // GOOGLE_CUDA
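Note on the pattern used throughout the functors above: each scalar hyperparameter (lr, momentum, beta1, ...) arrives as a ConstScalar, and reshape(single).broadcast(bcast) expands it to the gradient's flat length so it can be combined element-wise with the variable and gradient tensors in a single fused Eigen expression on the GPU device. The CPU-side sketch below is not part of the commit; it is a minimal illustration (function names such as apply_gradient_descent_cpu and apply_adam_cpu are hypothetical) that spells out the same element-wise updates for ApplyGradientDescent and ApplyAdam on plain arrays, with Adam's bias correction folded into one scalar step size as in the fused expression above.

// Minimal CPU-side sketch of the element-wise updates computed by the
// GPU functors above. Illustrative only; not part of the commit.
#include <cmath>
#include <cstdio>
#include <vector>

// ApplyGradientDescent: var -= alpha * delta
void apply_gradient_descent_cpu(std::vector<float>& var, float alpha,
                                const std::vector<float>& delta) {
  for (size_t i = 0; i < var.size(); ++i) var[i] -= alpha * delta[i];
}

// ApplyAdam: m and v are the first/second moment estimates; the scalar
// `step` folds in the bias correction, matching the fused GPU expression.
void apply_adam_cpu(std::vector<float>& var, std::vector<float>& m,
                    std::vector<float>& v, float beta1_power, float beta2_power,
                    float lr, float beta1, float beta2, float epsilon,
                    const std::vector<float>& grad) {
  const float step = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (size_t i = 0; i < var.size(); ++i) {
    m[i] += (1.0f - beta1) * (grad[i] - m[i]);
    v[i] += (1.0f - beta2) * (grad[i] * grad[i] - v[i]);
    var[i] -= step * m[i] / (epsilon + std::sqrt(v[i]));
  }
}

int main() {
  std::vector<float> var{1.0f, 2.0f}, m{0.0f, 0.0f}, v{0.0f, 0.0f};
  std::vector<float> grad{0.1f, -0.2f};
  apply_gradient_descent_cpu(var, 0.01f, grad);
  apply_adam_cpu(var, m, v, 0.9f, 0.999f, 0.001f, 0.9f, 0.999f, 1e-8f, grad);
  std::printf("var = [%f, %f]\n", var[0], var[1]);
  return 0;
}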