about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-01 10:55:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-01 10:55:25 -0700
commit: 5fdf2474245b84759d218a6330f20d4bcfdf7427 (patch)
tree: 5e9f1b2465517d836ab438ff0baecdd71f027575
parent: ec2b5f889fb3eb677f7b8198cbd8d505b2779fa7 (diff)
parent: a4eecdb369ecdae3b7fe7c1415d7b3b55bcc7b9e (diff)
Merge pull request #22122 from yongtang:22115-tf.contrib.image.transform-float16-gpu
PiperOrigin-RevId: 215240869
-rw-r--r-- tensorflow/contrib/image/kernels/image_ops.cc 2
-rw-r--r-- tensorflow/contrib/image/kernels/image_ops.h 7
-rw-r--r-- tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc 1
-rw-r--r-- tensorflow/contrib/image/python/kernel_tests/image_ops_test.py 9
4 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 370a8caf6a..788bf04b28 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -156,6 +156,7 @@ namespace functor {
TF_CALL_uint8(DECLARE_FUNCTOR);
TF_CALL_int32(DECLARE_FUNCTOR);
TF_CALL_int64(DECLARE_FUNCTOR);
+TF_CALL_half(DECLARE_FUNCTOR);
TF_CALL_float(DECLARE_FUNCTOR);
TF_CALL_double(DECLARE_FUNCTOR);
@@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR);
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 6b63eed130..7fac774d07 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -71,14 +71,7 @@ class ProjectiveGenerator {
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection;
- // TODO(ringwalt): Add a fill value input.
-#if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000)
- // On CUDA versions previous to 8.0, only __shared__ variables
- // could be declared as static in the device code.
const T fill_value = T(0);
-#else
- static const T fill_value = T(0);
-#endif
switch (interpolation_) {
case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
index 8743a5ff72..36b9a236a6 100644
--- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
+++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
@@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
template class FillProjectiveTransform<GPUDevice, uint8>;
template class FillProjectiveTransform<GPUDevice, int32>;
template class FillProjectiveTransform<GPUDevice, int64>;
+template class FillProjectiveTransform<GPUDevice, Eigen::half>;
template class FillProjectiveTransform<GPUDevice, float>;
template class FillProjectiveTransform<GPUDevice, double>;
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 376c0751ee..4997c31a7f 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -272,6 +272,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
with self.cached_session():
self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+ def test_transform_data_types(self):
+ for dtype in _DTYPES:
+ image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
+ value = image_ops.transform(image, [1] * 8)
+ with self.test_session(use_gpu=True):
+ self.assertAllEqual(
+ value.eval(),
+ np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()))
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):