diff options
Diffstat (limited to 'tensorflow/core/kernels/l2loss_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/l2loss_op_gpu.cu.cc | 49 |
1 file changed, 3 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc index 73b6472254..420df37086 100644 --- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc +++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc @@ -21,55 +21,12 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/reduction_ops_common.h" -#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h" - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; - -// TODO(eriche): can add specialization for half2 -template <typename T> -struct squareHalf { - __host__ __device__ T operator()(const T& x) const { - return static_cast<T>(0.5) * x * x; - } -}; - -template <typename T> -class L2LossOp<GPUDevice, T> : public OpKernel { - public: - explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - // The input tensor can be of any number of dimensions, even though it's - // 2D in most typical applications. - const Tensor& input = context->input(0); - // The output is a single number. - Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - typedef cub::TransformInputIterator<T, squareHalf<T>, T*> inputIterType; - inputIterType input_itr((T*)input.flat<T>().data(), squareHalf<T>()); - typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes; - - Constants<GPUDevice> constants; - functor::ReduceImpl<T, cub::Sum, T*, inputIterType, ReductionAxes>( - context, (T*)output->flat<T>().data(), input_itr, 1, - input.flat<T>().size(), 1, 1, 0, constants.kZero, cub::Sum(), T(0)); - } -}; - -// Registration of the GPU implementations. 
-#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("L2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ - L2LossOp<GPUDevice, T>); - -REGISTER_GPU_KERNEL(float); -REGISTER_GPU_KERNEL(double); -REGISTER_GPU_KERNEL(Eigen::half); -#undef REGISTER_GPU_KERNEL +template struct functor::L2Loss<GPUDevice, float>; +template struct functor::L2Loss<GPUDevice, double>; +template struct functor::L2Loss<GPUDevice, Eigen::half>; } // namespace tensorflow |