diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-22 18:48:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-22 18:52:57 -0700 |
commit | 4e9ce43b2060eab87d0f77718e676b17c71653a6 (patch) | |
tree | ec02b867bbd467769c52d91c4d7fa2927e620b46 | |
parent | 2b4780b9fede3c864dc3fb01de8c0106afb70f86 (diff) |
Convert Conv2D forward tests to run in both eager and graph modes.
PiperOrigin-RevId: 166146212
-rw-r--r-- | tensorflow/python/framework/test_util.py | 48 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/conv_ops_test.py | 99 | ||||
-rw-r--r-- | tensorflow/python/lib/core/py_func.cc | 3 | ||||
-rw-r--r-- | tensorflow/python/platform/test.py | 40 |
4 files changed, 115 insertions, 75 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 7fdb99cdd2..7478861158 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -326,6 +326,54 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, return decorator +def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): + """Returns whether TensorFlow can access a GPU. + + Args: + cuda_only: limit the search to CUDA gpus. + min_cuda_compute_capability: a (major,minor) pair that indicates the minimum + CUDA compute capability required, or None if no requirement. + + Returns: + True iff a gpu device of the requested kind is available. + """ + + def compute_capability_from_device_desc(device_desc): + # TODO(jingyue): The device description generator has to be in sync with + # this file. Another option is to put compute capability in + # DeviceAttributes, but I avoided that to keep DeviceAttributes + # target-independent. Reconsider this option when we have more things like + # this to keep in sync. + # LINT.IfChange + match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc) + # LINT.ThenChange(//tensorflow/core/\ + # common_runtime/gpu/gpu_device.cc) + if not match: + return 0, 0 + return int(match.group(1)), int(match.group(2)) + + for local_device in device_lib.list_local_devices(): + if local_device.device_type == "GPU": + if (min_cuda_compute_capability is None or + compute_capability_from_device_desc(local_device.physical_device_desc) + >= min_cuda_compute_capability): + return True + if local_device.device_type == "SYCL" and not cuda_only: + return True + return False + + +@contextlib.contextmanager +def device(use_gpu): + """Uses gpu when requested and available.""" + if use_gpu and is_gpu_available(): + dev = "/device:GPU:0" + else: + dev = "/device:CPU:0" + with ops.device(dev): + yield + + class TensorFlowTestCase(googletest.TestCase): """Base class for tests that need to test TensorFlow. """ diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index cdf8aa2719..4c5b72671c 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -189,7 +189,8 @@ class Conv2DTest(test.TestCase): # numbers from 1. x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] - with self.test_session(use_gpu=use_gpu): + + with test_util.device(use_gpu): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) strides = [1] + strides + [1] @@ -219,7 +220,7 @@ class Conv2DTest(test.TestCase): x2 = np.random.rand(*filter_in_sizes).astype(np.float32) def _SetupVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + with test_util.device(use_gpu): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) strides = [1] + conv_strides + [1] @@ -235,10 +236,9 @@ class Conv2DTest(test.TestCase): tensors = [] for (data_format, use_gpu) in GetTestConfigs(): tensors.append(_SetupVal(data_format, use_gpu)) - with self.test_session() as sess: - values = sess.run(tensors) - for i in range(1, len(values)): - self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5) + values = self.evaluate(tensors) + for i in range(1, len(values)): + self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5) def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding, expected): @@ -254,19 +254,19 @@ class Conv2DTest(test.TestCase): dtype, use_gpu=use_gpu) tensors.append(result) - with self.test_session() as sess: - values = sess.run(tensors) - for i in range(len(tensors)): - conv = tensors[i] - value = values[i] - print("expected = ", expected) - print("actual = ", value) - tol = 1e-5 - if value.dtype == np.float16: - tol = 1e-3 - self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) - self.assertShapeEqual(value, conv) + values = self.evaluate(tensors) + for i in range(len(tensors)): + conv = tensors[i] + value = values[i] + print("expected = ", expected) + print("actual = ", value) + tol = 1e-5 + if value.dtype == np.float16: + tol = 1e-3 + self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) + self.assertShapeEqual(value, conv) + @test_util.run_in_graph_and_eager_modes() def testConv2D1x1Filter(self): expected_output = [ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, @@ -279,6 +279,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2DEmpty(self): expected_output = [] self._VerifyValues( @@ -288,6 +289,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -298,6 +300,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2D1x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [ @@ -311,6 +314,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride2(self): expected_output = [2271.0, 2367.0, 2463.0] self._VerifyValues( @@ -320,6 +324,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride2Same(self): expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] self._VerifyValues( @@ -329,6 +334,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride1x2(self): expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0] self._VerifyValues( @@ -338,6 +344,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSmallerThanStrideValid(self): expected_output = [65, 95, 275, 305] self._VerifyValues( @@ -347,6 +354,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSmallerThanStrideSame(self): self._VerifyValues( tensor_in_sizes=[1, 3, 3, 1], @@ -369,6 +377,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=[44, 28, 41, 16]) + @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSizeMatchesInputSize(self): self._VerifyValues( tensor_in_sizes=[1, 2, 2, 1], @@ -397,7 +406,7 @@ class Conv2DTest(test.TestCase): # numbers from 1. x1 = [f * 1.0 for f in range(1, total_filter_size + 1)] x2 = [f * 1.0 for f in range(1, total_output_size + 1)] - with self.test_session(use_gpu=use_gpu) as sess: + with test_util.device(use_gpu): if data_format == "NCHW": input_sizes = test_util.NHWCToNCHW(input_sizes) t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) @@ -412,7 +421,7 @@ class Conv2DTest(test.TestCase): if data_format == "NCHW": conv = test_util.NCHWToNHWC(conv) # "values" consists of two tensors for two backprops - value = sess.run(conv) + value = self.evaluate(conv) self.assertShapeEqual(value, conv) print("expected = ", expected) print("actual = ", value) @@ -424,7 +433,7 @@ class Conv2DTest(test.TestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + with test_util.device(use_gpu): if data_format == "NCHW": new_input_sizes = test_util.NHWCToNCHW(input_sizes) else: @@ -445,7 +454,7 @@ class Conv2DTest(test.TestCase): data_format=data_format) if data_format == "NCHW": conv = test_util.NCHWToNHWC(conv) - ret = conv.eval() + ret = self.evaluate(conv) self.assertShapeEqual(ret, conv) return ret @@ -456,6 +465,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1ValidBackpropInput(self): expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -470,6 +480,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth3ValidBackpropInput(self): expected_output = [ 14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0, @@ -489,6 +500,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-4) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth3ValidBackpropInputStride1x2(self): expected_output = [ 1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0, @@ -506,6 +518,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) + @test_util.run_in_graph_and_eager_modes() def testConv2DStrideTwoFilterOneSameBackpropInput(self): expected_output = [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, @@ -523,6 +536,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) + @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSizeMatchesInputSizeBackpropInput(self): expected_output = [5.0, 11.0, 17.0, 23.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -552,7 +566,7 @@ class Conv2DTest(test.TestCase): x0 = [f * 1.0 for f in range(1, total_input_size + 1)] x2 = [f * 1.0 for f in range(1, total_output_size + 1)] for dtype in self._DtypesToTest(use_gpu=use_gpu): - with self.test_session(use_gpu=use_gpu) as sess: + with test_util.device(use_gpu): t0 = constant_op.constant(x0, shape=input_sizes, dtype=dtype) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = constant_op.constant(x2, shape=output_sizes, dtype=dtype) @@ -568,7 +582,7 @@ class Conv2DTest(test.TestCase): strides=explicit_strides, padding=padding, data_format=data_format) - value = sess.run(conv) + value = self.evaluate(conv) self.assertShapeEqual(value, conv) print("expected = ", expected) print("actual = ", value) @@ -580,7 +594,7 @@ class Conv2DTest(test.TestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + with test_util.device(use_gpu): t0 = constant_op.constant(x0, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = constant_op.constant(x2, shape=output_sizes) @@ -596,7 +610,7 @@ class Conv2DTest(test.TestCase): strides=strides, padding=padding, data_format=data_format) - ret = conv.eval() + ret = self.evaluate(conv) self.assertShapeEqual(ret, conv) return ret @@ -606,6 +620,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth1ValidBackpropFilter(self): expected = [5.0, 8.0, 14.0, 17.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -619,6 +634,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth3ValidBackpropFilter(self): expected = [ 17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0, @@ -637,6 +653,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) + @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self): expected = [161.0, 182.0, 287.0, 308.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -650,6 +667,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) + @test_util.run_in_graph_and_eager_modes() def testConv2DStrideTwoFilterOneSameBackpropFilter(self): expected_output = [78.] for (data_format, use_gpu) in GetTestConfigs(): @@ -663,6 +681,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) + @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self): expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -1446,13 +1465,18 @@ if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): setattr(Conv2DTest, "testInceptionFwd_" + str(index), - GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionFwdTest(input_size_, filter_size_, stride_, + padding_))) setattr(Conv2DTest, "testInceptionBackInput_" + str(index), - GetInceptionBackInputTest(input_size_, filter_size_, output_size_, - stride_, padding_)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionBackInputTest(input_size_, filter_size_, + output_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackFilter_" + str(index), - GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, - [stride_, stride_], padding_)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionBackFilterTest(input_size_, filter_size_, + output_size_, [stride_, stride_], + padding_))) # TODO(b/35359731) # Fwd, BckInput, and BackFilter to test that for certain input parameter @@ -1464,11 +1488,14 @@ if __name__ == "__main__": fshape = [1, 1, 1, 256] oshape = [1, 400, 400, 256] setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused", - GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused", - GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", - gpu_only=True)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", + gpu_only=True))) setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused", - GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME", - gpu_only=True)) + test_util.run_in_graph_and_eager_modes()( + GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME", + gpu_only=True))) test.main() diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index a1618d5349..84cb4885f6 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -79,6 +79,9 @@ Status MakeArgTuple(PyCall* call, PyObject** tuple) { // module. Status NumericNpDTypeToTfDType(const int np, DataType* tf) { switch (np) { + case NPY_FLOAT16: + *tf = DT_HALF; + break; case NPY_FLOAT32: *tf = DT_FLOAT; break; diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index c759455218..72025f6717 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -40,7 +40,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order -from tensorflow.python.client import device_lib as _device_lib from tensorflow.python.framework import test_util as _test_util from tensorflow.python.platform import googletest as _googletest from tensorflow.python.util.all_util import remove_undocumented @@ -50,12 +49,12 @@ from tensorflow.python.framework.test_util import assert_equal_graph_def from tensorflow.python.framework.test_util import create_local_cluster from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.framework.test_util import gpu_device_name +from tensorflow.python.framework.test_util import is_gpu_available from tensorflow.python.ops.gradient_checker import compute_gradient_error from tensorflow.python.ops.gradient_checker import compute_gradient # pylint: enable=unused-import,g-bad-import-order -import re as _re import sys if sys.version_info.major == 2: import mock # pylint: disable=g-import-not-at-top,unused-import @@ -103,43 +102,6 @@ def is_built_with_cuda(): return _test_util.IsGoogleCudaEnabled() -def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): - """Returns whether TensorFlow can access a GPU. - - Args: - cuda_only: limit the search to CUDA gpus. - min_cuda_compute_capability: a (major,minor) pair that indicates the minimum - CUDA compute capability required, or None if no requirement. - - Returns: - True iff a gpu device of the requested kind is available. - """ - - def compute_capability_from_device_desc(device_desc): - # TODO(jingyue): The device description generator has to be in sync with - # this file. Another option is to put compute capability in - # DeviceAttributes, but I avoided that to keep DeviceAttributes - # target-independent. Reconsider this option when we have more things like - # this to keep in sync. - # LINT.IfChange - match = _re.search(r'compute capability: (\d+)\.(\d+)', device_desc) - # LINT.ThenChange(//tensorflow/core/\ - # common_runtime/gpu/gpu_device.cc) - if not match: - return 0, 0 - return int(match.group(1)), int(match.group(2)) - - for local_device in _device_lib.list_local_devices(): - if local_device.device_type == 'GPU': - if (min_cuda_compute_capability is None or - compute_capability_from_device_desc(local_device.physical_device_desc) - >= min_cuda_compute_capability): - return True - if local_device.device_type == 'SYCL' and not cuda_only: - return True - return False - - _allowed_symbols = [ # We piggy-back googletest documentation. 'Benchmark', |