author A. Unique TensorFlower <nobody@tensorflow.org> 2016-05-27 04:24:11 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-05-27 05:34:02 -0700
commit 4f257a2427ba0414bd7513c9b61fb835870bd3cf (patch)
tree d74e61f75b385a40d927d3fa7d58bca3ac04bce0 /tensorflow/python/kernel_tests/conv_ops_test.py
parent c4cd581b55bcda3edff120b598919c216b0ae0f0 (diff)
Enable fp16 for convolution operations, gated on CUDA 7.5. (The fp16 tests
will not run under CUDA 7.0.) This is GPU-only for now; bugs in Eigen still block fp16 convolutions on CPU, but that should hopefully not last long.

Change: 123410990
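For context, a minimal sketch of what the change enables, assuming a GPU with CUDA 7.5 and the TF API of this era (the values here are illustrative, not taken from the commit):

    import tensorflow as tf

    # Before this change there was no GPU kernel registered for float16
    # convolutions, so building and running this graph would fail.
    x = tf.constant(1.0, shape=[1, 4, 4, 1], dtype=tf.float16)
    f = tf.constant(0.5, shape=[2, 2, 1, 1], dtype=tf.float16)
    conv = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
    with tf.Session() as sess:
      print(sess.run(conv))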
Diffstat (limited to 'tensorflow/python/kernel_tests/conv_ops_test.py')
-rw-r--r-- tensorflow/python/kernel_tests/conv_ops_test.py | 174
1 file changed, 95 insertions(+), 79 deletions(-)
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 50bb643402..de3b8d0691 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -163,8 +163,16 @@ def GetTestConfigs():
class Conv2DTest(tf.test.TestCase):
+ def _DtypesToTest(self):
+ if test_util.CudaSupportsHalfMatMulAndConv():
+ # It is important that float32 comes before float16 here,
+ # as we will be using its gradients as reference for fp16 gradients.
+ return [tf.float32, tf.float16]
+ else:
+ return [tf.float32]
+
def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, strides,
- padding, data_format, use_gpu):
+ padding, data_format, dtype, use_gpu):
"""Verifies the output values of the convolution function.
Args:
@@ -175,6 +183,7 @@ class Conv2DTest(tf.test.TestCase):
strides: Stride: [col_stride, row_stride]
padding: Padding type.
data_format: Format of the data tensors.
+ dtype: Data type for inputs and outputs.
use_gpu: True if the operations should be run on GPU
Returns:
Symbolic tensor value that can be used to execute the computation
@@ -190,8 +199,8 @@ class Conv2DTest(tf.test.TestCase):
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
- t1 = tf.constant(x1, shape=tensor_in_sizes)
- t2 = tf.constant(x2, shape=filter_in_sizes)
+ t1 = tf.constant(x1, shape=tensor_in_sizes, dtype=dtype)
+ t2 = tf.constant(x2, shape=filter_in_sizes, dtype=dtype)
strides = [1] + strides + [1]
if data_format == "NCHW":
t1 = NHWCToNCHW(t1)
@@ -246,24 +255,27 @@ class Conv2DTest(tf.test.TestCase):
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides,
padding, expected):
- tensors = []
- for (data_format, use_gpu) in GetTestConfigs():
- result = self._SetupValuesForDevice(tensor_in_sizes,
- filter_in_sizes,
- strides,
- padding,
- data_format,
- use_gpu=use_gpu)
- tensors.append(result)
- with self.test_session() as sess:
- values = sess.run(tensors)
- for i in range(len(tensors)):
- conv = tensors[i]
- value = values[i]
- print("expected = ", expected)
- print("actual = ", value)
- self.assertArrayNear(expected, np.ravel(value), 1e-5)
- self.assertShapeEqual(value, conv)
+ for dtype in self._DtypesToTest():
+ print(dtype)
+ tensors = []
+ for (data_format, use_gpu) in GetTestConfigs():
+ result = self._SetupValuesForDevice(tensor_in_sizes,
+ filter_in_sizes,
+ strides,
+ padding,
+ data_format,
+ dtype,
+ use_gpu=use_gpu)
+ tensors.append(result)
+ with self.test_session() as sess:
+ values = sess.run(tensors)
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ print("expected = ", expected)
+ print("actual = ", value)
+ self.assertAllCloseAccordingToType(expected, np.ravel(value))
+ self.assertShapeEqual(value, conv)
def testConv2D1x1Filter(self):
expected_output = [30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0,
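The switch from assertArrayNear(..., 1e-5) to assertAllCloseAccordingToType in the hunk above is what lets one assertion serve both dtypes: the helper loosens its tolerance when the values are fp16. A rough standalone sketch of the idea (the tolerances are illustrative; the real defaults live in tf.test.TestCase):

    import numpy as np

    def assert_close_according_to_type(expected, actual):
      # fp16 carries roughly 3 decimal digits, so it gets a much looser
      # bound than fp32; this mirrors, not reproduces, the TF helper.
      tol = 1e-3 if np.asarray(actual).dtype == np.float16 else 1e-6
      np.testing.assert_allclose(np.ravel(actual), expected, rtol=tol, atol=tol)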
@@ -494,26 +506,27 @@ class Conv2DTest(tf.test.TestCase):
# numbers from 1.
x0 = [f * 1.0 for f in range(1, total_input_size + 1)]
x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
- with self.test_session(use_gpu=use_gpu) as sess:
- t0 = tf.constant(x0, shape=input_sizes)
- t1 = tf.constant(filter_sizes, shape=[len(filter_sizes)])
- t2 = tf.constant(x2, shape=output_sizes)
- strides = [1] + strides + [1]
- if data_format == "NCHW":
- t0 = NHWCToNCHW(t0)
- t2 = NHWCToNCHW(t2)
- strides = NHWCToNCHW(strides)
- conv = tf.nn.conv2d_backprop_filter(t0,
- t1,
- t2,
- strides=strides,
- padding=padding,
- data_format=data_format)
- value = sess.run(conv)
- self.assertShapeEqual(value, conv)
- print("expected = ", expected)
- print("actual = ", value)
- self.assertArrayNear(expected, value.flatten(), 1e-5)
+ for dtype in self._DtypesToTest():
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t0 = tf.constant(x0, shape=input_sizes, dtype=dtype)
+ t1 = tf.constant(filter_sizes, shape=[len(filter_sizes)])
+ t2 = tf.constant(x2, shape=output_sizes, dtype=dtype)
+ explicit_strides = [1] + strides + [1]
+ if data_format == "NCHW":
+ t0 = NHWCToNCHW(t0)
+ t2 = NHWCToNCHW(t2)
+ explicit_strides = NHWCToNCHW(explicit_strides)
+ conv = tf.nn.conv2d_backprop_filter(t0,
+ t1,
+ t2,
+ strides=explicit_strides,
+ padding=padding,
+ data_format=data_format)
+ value = sess.run(conv)
+ self.assertShapeEqual(value, conv)
+ print("expected = ", expected)
+ print("actual = ", value)
+ self.assertArrayNear(expected, value.flatten(), 1e-5)
def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes,
conv_strides, padding):
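One subtlety in the hunk above: the local is renamed from strides to explicit_strides because the body now runs once per dtype, and rebinding the outer strides list on the first pass would feed an already-expanded list into the second. A hypothetical distillation of the hazard:

    strides = [2, 2]
    for dtype in ["float32", "float16"]:
      strides = [1] + strides + [1]  # grows again on every iteration
      print(dtype, strides)
    # float32 [1, 2, 2, 1]
    # float16 [1, 1, 2, 2, 1, 1]  <- corrupted on the second pass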
@@ -618,44 +631,47 @@ class Conv2DTest(tf.test.TestCase):
filter_size *= x
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
- with self.test_session(use_gpu=use_gpu):
- # Conv2DGrad functions are not compiled for double due to
- # a problem in the way Eigen's Conv2DGrad works for double.
- # So we disable the DOUBLE path. We should re-enable this
- # when double support returns for CPU and/or GPU.
- # data_type = tf.float64
- # tolerance = 1e-8
-
- data_type = tf.float32
- tolerance = 0.002
-
- input_tensor = tf.constant(input_data, shape=input_shape,
- dtype=data_type, name="input")
- filter_tensor = tf.constant(filter_data, shape=filter_shape,
- dtype=data_type, name="filter")
- strides = [1, stride_rows, stride_cols, 1]
- if data_format == "NCHW":
- new_input_tensor = NHWCToNCHW(input_tensor)
- strides = NHWCToNCHW(strides)
- else:
- new_input_tensor = input_tensor
- conv = tf.nn.conv2d(new_input_tensor,
- filter_tensor,
- strides,
- padding,
- data_format=data_format,
- name="conv")
- if data_format == "NCHW":
- conv = NCHWToNHWC(conv)
- self.assertEqual(output_shape, conv.get_shape())
- if test_input:
- err = tf.test.compute_gradient_error(input_tensor, input_shape, conv,
- output_shape)
- else:
- err = tf.test.compute_gradient_error(filter_tensor, filter_shape, conv,
- output_shape)
- print("conv_2d gradient error = ", err)
- self.assertLess(err, tolerance)
+ # Conv2DGrad functions are not compiled for double due to
+ # a problem in the way Eigen's Conv2DGrad works for double.
+ # So we disable the DOUBLE path. We should re-enable this
+ # when double support returns for CPU and/or GPU.
+ for dtype in self._DtypesToTest():
+ with self.test_session(use_gpu=use_gpu):
+ input_tensor = tf.constant(input_data, shape=input_shape,
+ dtype=dtype, name="input")
+ filter_tensor = tf.constant(filter_data, shape=filter_shape,
+ dtype=dtype, name="filter")
+ strides = [1, stride_rows, stride_cols, 1]
+ if data_format == "NCHW":
+ new_input_tensor = NHWCToNCHW(input_tensor)
+ strides = NHWCToNCHW(strides)
+ else:
+ new_input_tensor = input_tensor
+ conv = tf.nn.conv2d(new_input_tensor,
+ filter_tensor,
+ strides,
+ padding,
+ data_format=data_format,
+ name="conv")
+ if data_format == "NCHW":
+ conv = NCHWToNHWC(conv)
+ self.assertEqual(output_shape, conv.get_shape())
+ if test_input:
+ jacob_t, jacob_n = tf.test.compute_gradient(input_tensor, input_shape,
+ conv, output_shape)
+ else:
+ jacob_t, jacob_n = tf.test.compute_gradient(
+ filter_tensor, filter_shape, conv, output_shape)
+ if dtype == tf.float32:
+ reference_jacob_t = jacob_t
+ err = np.fabs(jacob_t - jacob_n).max()
+ else:
+ # Compare fp16 theoretical gradients to fp32 theoretical gradients,
+ # since fp16 numerical gradients are too imprecise.
+ err = np.fabs(jacob_t - reference_jacob_t).max()
+
+ print("conv_2d gradient error = ", err)
+ self.assertLess(err, 0.002)
def testInputGradientValidPaddingStrideOne(self):
for (data_format, use_gpu) in GetTestConfigs():
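The gradient hunk embodies the pattern that the float32-first ordering in _DtypesToTest exists for: fp16 numerical gradients from finite differences are too noisy to test against, so the fp16 theoretical jacobian is checked against the fp32 theoretical jacobian computed on the earlier pass. A condensed, self-contained sketch of that pattern (assumes a CUDA 7.5 GPU per the gating above; the toy convolution is illustrative, not the test's):

    import numpy as np
    import tensorflow as tf

    reference_jacob_t = None
    for dtype in [tf.float32, tf.float16]:  # float32 must come first
      with tf.Graph().as_default(), tf.Session():
        x = tf.constant(0.5, shape=[1, 3, 3, 1], dtype=dtype)
        f = tf.constant(0.2, shape=[2, 2, 1, 1], dtype=dtype)
        y = tf.nn.conv2d(x, f, [1, 1, 1, 1], "VALID")
        jacob_t, jacob_n = tf.test.compute_gradient(x, [1, 3, 3, 1],
                                                    y, [1, 2, 2, 1])
        if dtype == tf.float32:
          reference_jacob_t = jacob_t             # trusted fp32 jacobian
          err = np.fabs(jacob_t - jacob_n).max()  # fp32: theory vs. numeric
        else:
          # fp16 finite differences are too imprecise; compare against the
          # fp32 theoretical jacobian instead.
          err = np.fabs(jacob_t - reference_jacob_t).max()
        print(dtype, "max gradient error =", err)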