aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/colorspace_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/colorspace_op.cc')
-rw-r--r--tensorflow/core/kernels/colorspace_op.cc79
1 files changed, 46 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc
index 26f616f9b9..d65a34fd73 100644
--- a/tensorflow/core/kernels/colorspace_op.cc
+++ b/tensorflow/core/kernels/colorspace_op.cc
@@ -36,7 +36,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device>
+template <typename Device, typename T>
class RGBToHSVOp : public OpKernel {
public:
explicit RGBToHSVOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -59,23 +59,23 @@ class RGBToHSVOp : public OpKernel {
// Make a canonical image, maintaining the last (channel) dimension, while
// flattening all others do give the functor easy to work with data.
- TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>();
- TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>();
+ typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>();
+ typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>();
Tensor trange;
OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<float>::value,
+ context, context->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({input_data.dimension(0)}),
&trange));
- TTypes<float, 1>::Tensor range = trange.tensor<float, 1>();
+ typename TTypes<T, 1>::Tensor range = trange.tensor<T, 1>();
- functor::RGBToHSV<Device>()(context->eigen_device<Device>(), input_data,
- range, output_data);
+ functor::RGBToHSV<Device, T>()(context->eigen_device<Device>(), input_data,
+ range, output_data);
}
};
-template <typename Device>
+template <typename Device, typename T>
class HSVToRGBOp : public OpKernel {
public:
explicit HSVToRGBOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -96,41 +96,54 @@ class HSVToRGBOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
- TTypes<float, 2>::ConstTensor input_data = input.flat_inner_dims<float>();
- TTypes<float, 2>::Tensor output_data = output->flat_inner_dims<float>();
+ typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>();
+ typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>();
- functor::HSVToRGB<Device>()(context->eigen_device<Device>(), input_data,
- output_data);
+ functor::HSVToRGB<Device, T>()(context->eigen_device<Device>(), input_data,
+ output_data);
}
};
-REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU),
- RGBToHSVOp<CPUDevice>);
-template class RGBToHSVOp<CPUDevice>;
-REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU),
- HSVToRGBOp<CPUDevice>);
-template class HSVToRGBOp<CPUDevice>;
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ RGBToHSVOp<CPUDevice, T>); \
+ template class RGBToHSVOp<CPUDevice, T>; \
+ REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ HSVToRGBOp<CPUDevice, T>); \
+ template class HSVToRGBOp<CPUDevice, T>;
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
#if GOOGLE_CUDA
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {
-template <>
-void RGBToHSV<GPUDevice>::operator()(const GPUDevice& d,
- TTypes<float, 2>::ConstTensor input_data,
- TTypes<float, 1>::Tensor range,
- TTypes<float, 2>::Tensor output_data);
-extern template struct RGBToHSV<GPUDevice>;
-template <>
-void HSVToRGB<GPUDevice>::operator()(const GPUDevice& d,
- TTypes<float, 2>::ConstTensor input_data,
- TTypes<float, 2>::Tensor output_data);
-extern template struct HSVToRGB<GPUDevice>;
+#define DECLARE_GPU(T) \
+ template <> \
+ void RGBToHSV<GPUDevice, T>::operator()(const GPUDevice& d, \
+ TTypes<T, 2>::ConstTensor input_data, \
+ TTypes<T, 1>::Tensor range, \
+ TTypes<T, 2>::Tensor output_data); \
+ extern template struct RGBToHSV<GPUDevice, T>; \
+ template <> \
+ void HSVToRGB<GPUDevice, T>::operator()(const GPUDevice& d, \
+ TTypes<T, 2>::ConstTensor input_data, \
+ TTypes<T, 2>::Tensor output_data); \
+ extern template struct HSVToRGB<GPUDevice, T>;
+TF_CALL_float(DECLARE_GPU);
+TF_CALL_double(DECLARE_GPU);
} // namespace functor
-REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU),
- RGBToHSVOp<GPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU),
- HSVToRGBOp<GPUDevice>);
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T"), \
+ RGBToHSVOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T"), \
+ HSVToRGBOp<GPUDevice, T>);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
#endif
} // namespace tensorflow