diff options
Diffstat (limited to 'tensorflow/core/kernels/colorspace_op.cc')
-rw-r--r-- | tensorflow/core/kernels/colorspace_op.cc | 79 |
1 files changed, 46 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc index 26f616f9b9..d65a34fd73 100644 --- a/tensorflow/core/kernels/colorspace_op.cc +++ b/tensorflow/core/kernels/colorspace_op.cc @@ -36,7 +36,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -template <typename Device> +template <typename Device, typename T> class RGBToHSVOp : public OpKernel { public: explicit RGBToHSVOp(OpKernelConstruction* context) : OpKernel(context) {} @@ -59,23 +59,23 @@ class RGBToHSVOp : public OpKernel { // Make a canonical image, maintaining the last (channel) dimension, while // flattening all others do give the functor easy to work with data. - TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>(); - TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>(); + typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>(); + typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>(); Tensor trange; OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum<float>::value, + context, context->allocate_temp(DataTypeToEnum<T>::value, TensorShape({input_data.dimension(0)}), &trange)); - TTypes<float, 1>::Tensor range = trange.tensor<float, 1>(); + typename TTypes<T, 1>::Tensor range = trange.tensor<T, 1>(); - functor::RGBToHSV<Device>()(context->eigen_device<Device>(), input_data, - range, output_data); + functor::RGBToHSV<Device, T>()(context->eigen_device<Device>(), input_data, + range, output_data); } }; -template <typename Device> +template <typename Device, typename T> class HSVToRGBOp : public OpKernel { public: explicit HSVToRGBOp(OpKernelConstruction* context) : OpKernel(context) {} @@ -96,41 +96,54 @@ class HSVToRGBOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); - TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>(); - TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>(); + typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>(); + typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>(); - functor::HSVToRGB<Device>()(context->eigen_device<Device>(), input_data, - output_data); + functor::HSVToRGB<Device, T>()(context->eigen_device<Device>(), input_data, + output_data); } }; -REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU), - RGBToHSVOp<CPUDevice>); -template class RGBToHSVOp<CPUDevice>; -REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU), - HSVToRGBOp<CPUDevice>); -template class HSVToRGBOp<CPUDevice>; +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + RGBToHSVOp<CPUDevice, T>); \ + template class RGBToHSVOp<CPUDevice, T>; \ + REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + HSVToRGBOp<CPUDevice, T>); \ + template class HSVToRGBOp<CPUDevice, T>; +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); #if GOOGLE_CUDA // Forward declarations of the function specializations for GPU (to prevent // building the GPU versions here, they will be built compiling _gpu.cu.cc). namespace functor { -template <> -void RGBToHSV<GPUDevice>::operator()(const GPUDevice& d, - TTypes<float, 2>::ConstTensor input_data, - TTypes<float, 1>::Tensor range, - TTypes<float, 2>::Tensor output_data); -extern template struct RGBToHSV<GPUDevice>; -template <> -void HSVToRGB<GPUDevice>::operator()(const GPUDevice& d, - TTypes<float, 2>::ConstTensor input_data, - TTypes<float, 2>::Tensor output_data); -extern template struct HSVToRGB<GPUDevice>; +#define DECLARE_GPU(T) \ + template <> \ + void RGBToHSV<GPUDevice, T>::operator()(const GPUDevice& d, \ + TTypes<T, 2>::ConstTensor input_data, \ + TTypes<T, 1>::Tensor range, \ + TTypes<T, 2>::Tensor output_data); \ + extern template struct RGBToHSV<GPUDevice, T>; \ + template <> \ + void HSVToRGB<GPUDevice, T>::operator()(const GPUDevice& d, \ + TTypes<T, 2>::ConstTensor input_data, \ + TTypes<T, 2>::Tensor output_data); \ + extern template struct HSVToRGB<GPUDevice, T>; +TF_CALL_float(DECLARE_GPU); +TF_CALL_double(DECLARE_GPU); } // namespace functor -REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU), - RGBToHSVOp<GPUDevice>); -REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU), - HSVToRGBOp<GPUDevice>); +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU) \ + .TypeConstraint<T>("T"), \ + RGBToHSVOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU) \ + .TypeConstraint<T>("T"), \ + HSVToRGBOp<GPUDevice, T>); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); #endif } // namespace tensorflow |