aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-09-06 19:20:11 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-09-06 19:20:11 +0000
commite3654a3cb4e26c26409aeeb9e127e3addcb14cee (patch)
tree9228b0118dae4a3b3179983ea8962365858304cf /tensorflow/contrib/image
parent84ada6e2ce3d830f5cf3490e30f408f7459d0eab (diff)
Add float16 support on GPU for tf.contrib.image.transform
This fix tries to address the issue raised in 22115 where there were no float16 support on GPU for tf.contrib.image.transform. This fix fixes 22115. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc1
2 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 370a8caf6a..788bf04b28 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -156,6 +156,7 @@ namespace functor {
TF_CALL_uint8(DECLARE_FUNCTOR);
TF_CALL_int32(DECLARE_FUNCTOR);
TF_CALL_int64(DECLARE_FUNCTOR);
+TF_CALL_half(DECLARE_FUNCTOR);
TF_CALL_float(DECLARE_FUNCTOR);
TF_CALL_double(DECLARE_FUNCTOR);
@@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR);
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
index 8743a5ff72..36b9a236a6 100644
--- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
+++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
@@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
template class FillProjectiveTransform<GPUDevice, uint8>;
template class FillProjectiveTransform<GPUDevice, int32>;
template class FillProjectiveTransform<GPUDevice, int64>;
+template class FillProjectiveTransform<GPUDevice, Eigen::half>;
template class FillProjectiveTransform<GPUDevice, float>;
template class FillProjectiveTransform<GPUDevice, double>;