diff options
author | 2018-09-06 19:20:11 +0000 | |
---|---|---|
committer | 2018-09-06 19:20:11 +0000 | |
commit | e3654a3cb4e26c26409aeeb9e127e3addcb14cee (patch) | |
tree | 9228b0118dae4a3b3179983ea8962365858304cf /tensorflow/contrib/image | |
parent | 84ada6e2ce3d830f5cf3490e30f408f7459d0eab (diff) |
Add float16 support on GPU for tf.contrib.image.transform
This fix tries to address the issue raised in 22115 where
there were no float16 support on GPU for tf.contrib.image.transform.
This fix fixes 22115.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc | 1 |
2 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 370a8caf6a..788bf04b28 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -156,6 +156,7 @@ namespace functor { TF_CALL_uint8(DECLARE_FUNCTOR); TF_CALL_int32(DECLARE_FUNCTOR); TF_CALL_int64(DECLARE_FUNCTOR); +TF_CALL_half(DECLARE_FUNCTOR); TF_CALL_float(DECLARE_FUNCTOR); TF_CALL_double(DECLARE_FUNCTOR); @@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR); TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc index 8743a5ff72..36b9a236a6 100644 --- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc @@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice; template class FillProjectiveTransform<GPUDevice, uint8>; template class FillProjectiveTransform<GPUDevice, int32>; template class FillProjectiveTransform<GPUDevice, int64>; +template class FillProjectiveTransform<GPUDevice, Eigen::half>; template class FillProjectiveTransform<GPUDevice, float>; template class FillProjectiveTransform<GPUDevice, double>; |