diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-19 15:09:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 15:17:59 -0700 |
commit | 2d0a7087a14f015ea49f4b8feb70e0b5ecd41b28 (patch) | |
tree | 308854b6a62238e65c03e452cd31e0b8d6b28858 /tensorflow/compiler/tests/image_ops_test.py | |
parent | 470842748b9ee219fa0fcb8e3de25720960c83e3 (diff) |
Only generate floating points that are fractions like n / 256, since they are RGB pixels. This fixes RGBToHSVTest.testBatch on low-precision dtypes like bfloat16.
PiperOrigin-RevId: 193581652
Diffstat (limited to 'tensorflow/compiler/tests/image_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/image_ops_test.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 5b19e993ec..42e637734c 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -34,20 +34,23 @@ from tensorflow.python.ops import image_ops from tensorflow.python.platform import test +def GenerateNumpyRandomRGB(shape): + # Only generate floating points that are fractions like n / 256, since they + # are RGB pixels. Some low-precision floating point types in this test can't + # handle arbitrary precision floating points well. + return np.random.randint(0, 256, shape) / 256. + + class RGBToHSVTest(XLATestCase): def testBatch(self): - # TODO(b/78230407): Reenable the test on GPU. - if self.device == "XLA_GPU": - return - # Build an arbitrary RGB image np.random.seed(7) batch_size = 5 shape = (batch_size, 2, 7, 3) for nptype in self.float_types: - inp = np.random.rand(*shape).astype(nptype) + inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually with self.test_session() as sess: @@ -87,7 +90,7 @@ class RGBToHSVTest(XLATestCase): def testRGBToHSVNumpy(self): """Tests the RGB to HSV conversion matches a reference implementation.""" for nptype in self.float_types: - rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) + rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype) rgb_np = rgb_flat.reshape(4, 4, 4, 3) hsv_np = np.array([ colorsys.rgb_to_hsv( |