aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/spacetodepth_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/spacetodepth_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/spacetodepth_op_test.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index b76135764f..cd90d16aac 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -34,8 +34,8 @@ from tensorflow.python.platform import tf_logging
class SpaceToDepthTest(test.TestCase):
- def _testOne(self, inputs, block_size, outputs):
- input_nhwc = math_ops.to_float(inputs)
+ def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32):
+ input_nhwc = math_ops.cast(inputs, dtype)
with self.test_session(use_gpu=False):
# test NHWC (default) on CPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
@@ -58,6 +58,12 @@ class SpaceToDepthTest(test.TestCase):
x_out = [[[[1, 2, 3, 4]]]]
self._testOne(x_np, block_size, x_out)
+ def testBasicFloat16(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ block_size = 2
+ x_out = [[[[1, 2, 3, 4]]]]
+ self._testOne(x_np, block_size, x_out, dtype=dtypes.float16)
+
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
def testLargerInput2x2(self):