aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/l2loss_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/l2loss_op_gpu.cu.cc49
1 files changed, 3 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
index 73b6472254..420df37086 100644
--- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
@@ -21,55 +21,12 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/kernels/reduction_ops_common.h"
-#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h"
-
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-
-// TODO(eriche): can add specialization for half2
-template <typename T>
-struct squareHalf {
- __host__ __device__ T operator()(const T& x) const {
- return static_cast<T>(0.5) * x * x;
- }
-};
-
-template <typename T>
-class L2LossOp<GPUDevice, T> : public OpKernel {
- public:
- explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // The input tensor can be of any number of dimensions, even though it's
- // 2D in most typical applications.
- const Tensor& input = context->input(0);
- // The output is a single number.
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, TensorShape({}), &output));
- typedef cub::TransformInputIterator<T, squareHalf<T>, T*> inputIterType;
- inputIterType input_itr((T*)input.flat<T>().data(), squareHalf<T>());
- typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
-
- Constants<GPUDevice> constants;
- functor::ReduceImpl<T, cub::Sum, T*, inputIterType, ReductionAxes>(
- context, (T*)output->flat<T>().data(), input_itr, 1,
- input.flat<T>().size(), 1, 1, 0, constants.kZero, cub::Sum(), T(0));
- }
-};
-
-// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNEL(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("L2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- L2LossOp<GPUDevice, T>);
-
-REGISTER_GPU_KERNEL(float);
-REGISTER_GPU_KERNEL(double);
-REGISTER_GPU_KERNEL(Eigen::half);
-#undef REGISTER_GPU_KERNEL
+template struct functor::L2Loss<GPUDevice, float>;
+template struct functor::L2Loss<GPUDevice, double>;
+template struct functor::L2Loss<GPUDevice, Eigen::half>;
} // namespace tensorflow