aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-09-06 19:22:39 +0000
committerGravatar Yong Tang <yong.tang.github@outlook.com>2018-09-06 19:22:39 +0000
commit7d7e8a725aeede4b724f7376d22df2c7f2ebdcf9 (patch)
treecfefdfad8fd22666387f0f5e6788dd631b6e38ba
parente3654a3cb4e26c26409aeeb9e127e3addcb14cee (diff)
Add test case for float16 support on GPU for tf.contrib.image.transform
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 376c0751ee..ef1f79bb94 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -272,6 +272,13 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
with self.cached_session():
self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+ def test_transform_data_types(self):
+ for dtype in _DTYPES:
+ image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
+ value = image_ops.transform(image, [1] * 8)
+ with self.test_session(use_gpu=True):
+ self.assertAllEqual(value.eval(), np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()))
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):