aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/constant_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/constant_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/constant_op_gpu.cu.cc89
1 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc
new file mode 100644
index 0000000000..64502378bd
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc
@@ -0,0 +1,89 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace Eigen {
+namespace internal {
+
+// Nullary functor that produces the value stored at `*val` for every
+// coefficient, ignoring the index arguments it is given.
+//
+// The functor keeps only a pointer and dereferences it each time it is
+// invoked, so `val` must be readable from wherever the functor is evaluated.
+// NOTE(review): presumably `val` points to device memory when this is used
+// with GpuDevice -- confirm against the caller.
+template <typename T>
+struct scalar_const_op {
+  typedef typename packet_traits<T>::type Packet;
+
+  // Location of the constant. The functor never frees it; copies of the
+  // functor share the same pointer.
+  const T* val;
+
+  // Copying the functor copies the pointer, not the pointed-to value.
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  scalar_const_op(const scalar_const_op& x)
+      : val(x.val) {}
+
+  // Wraps the address `v` of the value to be produced.
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
+
+  // Scalar path: returns `*val` regardless of the requested index.
+  template <typename Index>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(Index,
+                                                           Index = 0) const {
+    return *val;
+  }
+
+  // Packet path: builds a Packet from `*val` with pset1, regardless of the
+  // requested index. Marked EIGEN_DEVICE_FUNC like the scalar path, because
+  // functor_traits advertises PacketAccess for vectorizable T and the packet
+  // path may then be invoked from the same contexts as operator().
+  template <typename Index>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(Index,
+                                                              Index = 0) const {
+    return internal::pset1<Packet>(*val);
+  }
+};
+
+// Traits describing scalar_const_op<T> to Eigen's expression evaluators.
+template <typename T>
+struct functor_traits<scalar_const_op<T> > {
+  enum {
+    // The functor ignores its index arguments, so repeated evaluation is
+    // declared repeatable.
+    IsRepeatable = true,
+    // packetOp() is offered exactly when packet_traits reports T as
+    // vectorizable.
+    PacketAccess = packet_traits<T>::Vectorizable,
+    // Declared evaluation cost of one call.
+    Cost = 1
+  };
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization FillFunctor<Device=GPUDevice, T>.
+// Overwrites every element of `out` with the scalar that `in` refers to,
+// evaluating the assignment on device `d`.
+template <typename T>
+struct FillFunctor<GPUDevice, T> {
+  // d:   device the assignment is evaluated on.
+  // out: flat view of the tensor to fill.
+  // in:  scalar view holding the fill value; only its data() pointer is
+  //      captured here, the value itself is read by the functor later.
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+    typedef Eigen::internal::scalar_const_op<T> ConstOp;
+    ConstOp fill_op(in.data());
+    out.device(d) = out.nullaryExpr(fill_op);
+  }
+};
+
+// Explicit instantiations of FillFunctor<GPUDevice, T> emitted by this file.
+#define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>
+// Floating-point element types.
+DEFINE_FILL_GPU(float);
+DEFINE_FILL_GPU(double);
+// Integer element types.
+DEFINE_FILL_GPU(int8);
+DEFINE_FILL_GPU(uint8);
+DEFINE_FILL_GPU(int16);
+DEFINE_FILL_GPU(int32);
+DEFINE_FILL_GPU(int64);
+#undef DEFINE_FILL_GPU
+
+// Partial specialization of SetZeroFunctor<Device=GPUDevice, T>.
+// Assigns a constant-zero expression to every element of `out`, evaluating
+// the assignment on device `d`.
+template <typename T>
+struct SetZeroFunctor<GPUDevice, T> {
+  // d:   device the assignment is evaluated on.
+  // out: flat view of the tensor to overwrite with zeros.
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
+    // Spell the zero as T(0) rather than a bare int literal so the argument
+    // already has the element type and no implicit int -> T conversion is
+    // required to instantiate this functor.
+    out.device(d) = out.constant(T(0));
+  }
+};
+
+// Explicit instantiation of SetZeroFunctor on GPUDevice. Only the float
+// instantiation is emitted in this file; the helper macro is local and is
+// undefined again right after use.
+#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>
+DEFINE_SETZERO_GPU(float);
+#undef DEFINE_SETZERO_GPU
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA