aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/image_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-19 15:09:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 15:17:59 -0700
commit2d0a7087a14f015ea49f4b8feb70e0b5ecd41b28 (patch)
tree308854b6a62238e65c03e452cd31e0b8d6b28858 /tensorflow/compiler/tests/image_ops_test.py
parent470842748b9ee219fa0fcb8e3de25720960c83e3 (diff)
Only generate floating points that are fractions like n / 256, since they are RGB pixels. This fixes RGBToHSVTest.testBatch on low-precision dtypes like bfloat16.
PiperOrigin-RevId: 193581652
Diffstat (limited to 'tensorflow/compiler/tests/image_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 5b19e993ec..42e637734c 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -34,20 +34,23 @@ from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
+def GenerateNumpyRandomRGB(shape):
+ # Only generate floating points that are fractions like n / 256, since they
+ # are RGB pixels. Some low-precision floating point types in this test can't
+ # handle arbitrary precision floating points well.
+ return np.random.randint(0, 256, shape) / 256.
+
+
class RGBToHSVTest(XLATestCase):
def testBatch(self):
- # TODO(b/78230407): Reenable the test on GPU.
- if self.device == "XLA_GPU":
- return
-
# Build an arbitrary RGB image
np.random.seed(7)
batch_size = 5
shape = (batch_size, 2, 7, 3)
for nptype in self.float_types:
- inp = np.random.rand(*shape).astype(nptype)
+ inp = GenerateNumpyRandomRGB(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
with self.test_session() as sess:
@@ -87,7 +90,7 @@ class RGBToHSVTest(XLATestCase):
def testRGBToHSVNumpy(self):
"""Tests the RGB to HSV conversion matches a reference implementation."""
for nptype in self.float_types:
- rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype)
+ rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype)
rgb_np = rgb_flat.reshape(4, 4, 4, 3)
hsv_np = np.array([
colorsys.rgb_to_hsv(