aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-22 18:48:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-22 18:52:57 -0700
commit4e9ce43b2060eab87d0f77718e676b17c71653a6 (patch)
treeec02b867bbd467769c52d91c4d7fa2927e620b46
parent2b4780b9fede3c864dc3fb01de8c0106afb70f86 (diff)
Convert Conv2D forward tests to run in both eager and graph modes.
PiperOrigin-RevId: 166146212
-rw-r--r--tensorflow/python/framework/test_util.py48
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py99
-rw-r--r--tensorflow/python/lib/core/py_func.cc3
-rw-r--r--tensorflow/python/platform/test.py40
4 files changed, 115 insertions, 75 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 7fdb99cdd2..7478861158 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -326,6 +326,54 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
return decorator
+def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
+ """Returns whether TensorFlow can access a GPU.
+
+ Args:
+ cuda_only: limit the search to CUDA gpus.
+ min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
+ CUDA compute capability required, or None if no requirement.
+
+ Returns:
+ True iff a gpu device of the requested kind is available.
+ """
+
+ def compute_capability_from_device_desc(device_desc):
+ # TODO(jingyue): The device description generator has to be in sync with
+ # this file. Another option is to put compute capability in
+ # DeviceAttributes, but I avoided that to keep DeviceAttributes
+ # target-independent. Reconsider this option when we have more things like
+ # this to keep in sync.
+ # LINT.IfChange
+ match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
+ # LINT.ThenChange(//tensorflow/core/\
+ # common_runtime/gpu/gpu_device.cc)
+ if not match:
+ return 0, 0
+ return int(match.group(1)), int(match.group(2))
+
+ for local_device in device_lib.list_local_devices():
+ if local_device.device_type == "GPU":
+ if (min_cuda_compute_capability is None or
+ compute_capability_from_device_desc(local_device.physical_device_desc)
+ >= min_cuda_compute_capability):
+ return True
+ if local_device.device_type == "SYCL" and not cuda_only:
+ return True
+ return False
+
+
+@contextlib.contextmanager
+def device(use_gpu):
+ """Uses gpu when requested and available."""
+ if use_gpu and is_gpu_available():
+ dev = "/device:GPU:0"
+ else:
+ dev = "/device:CPU:0"
+ with ops.device(dev):
+ yield
+
+
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
"""
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index cdf8aa2719..4c5b72671c 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -189,7 +189,8 @@ class Conv2DTest(test.TestCase):
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session(use_gpu=use_gpu):
+
+ with test_util.device(use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
strides = [1] + strides + [1]
@@ -219,7 +220,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
def _SetupVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu):
+ with test_util.device(use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes)
strides = [1] + conv_strides + [1]
@@ -235,10 +236,9 @@ class Conv2DTest(test.TestCase):
tensors = []
for (data_format, use_gpu) in GetTestConfigs():
tensors.append(_SetupVal(data_format, use_gpu))
- with self.test_session() as sess:
- values = sess.run(tensors)
- for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
+ values = self.evaluate(tensors)
+ for i in range(1, len(values)):
+ self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
expected):
@@ -254,19 +254,19 @@ class Conv2DTest(test.TestCase):
dtype,
use_gpu=use_gpu)
tensors.append(result)
- with self.test_session() as sess:
- values = sess.run(tensors)
- for i in range(len(tensors)):
- conv = tensors[i]
- value = values[i]
- print("expected = ", expected)
- print("actual = ", value)
- tol = 1e-5
- if value.dtype == np.float16:
- tol = 1e-3
- self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol)
- self.assertShapeEqual(value, conv)
+ values = self.evaluate(tensors)
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ print("expected = ", expected)
+ print("actual = ", value)
+ tol = 1e-5
+ if value.dtype == np.float16:
+ tol = 1e-3
+ self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol)
+ self.assertShapeEqual(value, conv)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D1x1Filter(self):
expected_output = [
30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
@@ -279,6 +279,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DEmpty(self):
expected_output = []
self._VerifyValues(
@@ -288,6 +289,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
@@ -298,6 +300,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D1x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [
@@ -311,6 +314,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2FilterStride2(self):
expected_output = [2271.0, 2367.0, 2463.0]
self._VerifyValues(
@@ -320,6 +324,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2FilterStride2Same(self):
expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
self._VerifyValues(
@@ -329,6 +334,7 @@ class Conv2DTest(test.TestCase):
padding="SAME",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2FilterStride1x2(self):
expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0]
self._VerifyValues(
@@ -338,6 +344,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DKernelSmallerThanStrideValid(self):
expected_output = [65, 95, 275, 305]
self._VerifyValues(
@@ -347,6 +354,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DKernelSmallerThanStrideSame(self):
self._VerifyValues(
tensor_in_sizes=[1, 3, 3, 1],
@@ -369,6 +377,7 @@ class Conv2DTest(test.TestCase):
padding="SAME",
expected=[44, 28, 41, 16])
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DKernelSizeMatchesInputSize(self):
self._VerifyValues(
tensor_in_sizes=[1, 2, 2, 1],
@@ -397,7 +406,7 @@ class Conv2DTest(test.TestCase):
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_filter_size + 1)]
x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
- with self.test_session(use_gpu=use_gpu) as sess:
+ with test_util.device(use_gpu):
if data_format == "NCHW":
input_sizes = test_util.NHWCToNCHW(input_sizes)
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
@@ -412,7 +421,7 @@ class Conv2DTest(test.TestCase):
if data_format == "NCHW":
conv = test_util.NCHWToNHWC(conv)
# "values" consists of two tensors for two backprops
- value = sess.run(conv)
+ value = self.evaluate(conv)
self.assertShapeEqual(value, conv)
print("expected = ", expected)
print("actual = ", value)
@@ -424,7 +433,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu):
+ with test_util.device(use_gpu):
if data_format == "NCHW":
new_input_sizes = test_util.NHWCToNCHW(input_sizes)
else:
@@ -445,7 +454,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format)
if data_format == "NCHW":
conv = test_util.NCHWToNHWC(conv)
- ret = conv.eval()
+ ret = self.evaluate(conv)
self.assertShapeEqual(ret, conv)
return ret
@@ -456,6 +465,7 @@ class Conv2DTest(test.TestCase):
for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth1ValidBackpropInput(self):
expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -470,6 +480,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth3ValidBackpropInput(self):
expected_output = [
14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0,
@@ -489,6 +500,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-4)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth3ValidBackpropInputStride1x2(self):
expected_output = [
1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0,
@@ -506,6 +518,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DStrideTwoFilterOneSameBackpropInput(self):
expected_output = [
1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0,
@@ -523,6 +536,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DKernelSizeMatchesInputSizeBackpropInput(self):
expected_output = [5.0, 11.0, 17.0, 23.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -552,7 +566,7 @@ class Conv2DTest(test.TestCase):
x0 = [f * 1.0 for f in range(1, total_input_size + 1)]
x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
for dtype in self._DtypesToTest(use_gpu=use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
+ with test_util.device(use_gpu):
t0 = constant_op.constant(x0, shape=input_sizes, dtype=dtype)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = constant_op.constant(x2, shape=output_sizes, dtype=dtype)
@@ -568,7 +582,7 @@ class Conv2DTest(test.TestCase):
strides=explicit_strides,
padding=padding,
data_format=data_format)
- value = sess.run(conv)
+ value = self.evaluate(conv)
self.assertShapeEqual(value, conv)
print("expected = ", expected)
print("actual = ", value)
@@ -580,7 +594,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu):
+ with test_util.device(use_gpu):
t0 = constant_op.constant(x0, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = constant_op.constant(x2, shape=output_sizes)
@@ -596,7 +610,7 @@ class Conv2DTest(test.TestCase):
strides=strides,
padding=padding,
data_format=data_format)
- ret = conv.eval()
+ ret = self.evaluate(conv)
self.assertShapeEqual(ret, conv)
return ret
@@ -606,6 +620,7 @@ class Conv2DTest(test.TestCase):
for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth1ValidBackpropFilter(self):
expected = [5.0, 8.0, 14.0, 17.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -619,6 +634,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth3ValidBackpropFilter(self):
expected = [
17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0,
@@ -637,6 +653,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self):
expected = [161.0, 182.0, 287.0, 308.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -650,6 +667,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DStrideTwoFilterOneSameBackpropFilter(self):
expected_output = [78.]
for (data_format, use_gpu) in GetTestConfigs():
@@ -663,6 +681,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self):
expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -1446,13 +1465,18 @@ if __name__ == "__main__":
for index, (input_size_, filter_size_, output_size_, stride_,
padding_) in enumerate(GetShrunkInceptionShapes()):
setattr(Conv2DTest, "testInceptionFwd_" + str(index),
- GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionFwdTest(input_size_, filter_size_, stride_,
+ padding_)))
setattr(Conv2DTest, "testInceptionBackInput_" + str(index),
- GetInceptionBackInputTest(input_size_, filter_size_, output_size_,
- stride_, padding_))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionBackInputTest(input_size_, filter_size_,
+ output_size_, stride_, padding_)))
setattr(Conv2DTest, "testInceptionBackFilter_" + str(index),
- GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
- [stride_, stride_], padding_))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionBackFilterTest(input_size_, filter_size_,
+ output_size_, [stride_, stride_],
+ padding_)))
# TODO(b/35359731)
# Fwd, BckInput, and BackFilter to test that for certain input parameter
@@ -1464,11 +1488,14 @@ if __name__ == "__main__":
fshape = [1, 1, 1, 256]
oshape = [1, 400, 400, 256]
setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
- GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)))
setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
- GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME",
- gpu_only=True))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME",
+ gpu_only=True)))
setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
- GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME",
- gpu_only=True))
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME",
+ gpu_only=True)))
test.main()
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index a1618d5349..84cb4885f6 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -79,6 +79,9 @@ Status MakeArgTuple(PyCall* call, PyObject** tuple) {
// module.
Status NumericNpDTypeToTfDType(const int np, DataType* tf) {
switch (np) {
+ case NPY_FLOAT16:
+ *tf = DT_HALF;
+ break;
case NPY_FLOAT32:
*tf = DT_FLOAT;
break;
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index c759455218..72025f6717 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -40,7 +40,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
-from tensorflow.python.client import device_lib as _device_lib
from tensorflow.python.framework import test_util as _test_util
from tensorflow.python.platform import googletest as _googletest
from tensorflow.python.util.all_util import remove_undocumented
@@ -50,12 +49,12 @@ from tensorflow.python.framework.test_util import assert_equal_graph_def
from tensorflow.python.framework.test_util import create_local_cluster
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import gpu_device_name
+from tensorflow.python.framework.test_util import is_gpu_available
from tensorflow.python.ops.gradient_checker import compute_gradient_error
from tensorflow.python.ops.gradient_checker import compute_gradient
# pylint: enable=unused-import,g-bad-import-order
-import re as _re
import sys
if sys.version_info.major == 2:
import mock # pylint: disable=g-import-not-at-top,unused-import
@@ -103,43 +102,6 @@ def is_built_with_cuda():
return _test_util.IsGoogleCudaEnabled()
-def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
- """Returns whether TensorFlow can access a GPU.
-
- Args:
- cuda_only: limit the search to CUDA gpus.
- min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
- CUDA compute capability required, or None if no requirement.
-
- Returns:
- True iff a gpu device of the requested kind is available.
- """
-
- def compute_capability_from_device_desc(device_desc):
- # TODO(jingyue): The device description generator has to be in sync with
- # this file. Another option is to put compute capability in
- # DeviceAttributes, but I avoided that to keep DeviceAttributes
- # target-independent. Reconsider this option when we have more things like
- # this to keep in sync.
- # LINT.IfChange
- match = _re.search(r'compute capability: (\d+)\.(\d+)', device_desc)
- # LINT.ThenChange(//tensorflow/core/\
- # common_runtime/gpu/gpu_device.cc)
- if not match:
- return 0, 0
- return int(match.group(1)), int(match.group(2))
-
- for local_device in _device_lib.list_local_devices():
- if local_device.device_type == 'GPU':
- if (min_cuda_compute_capability is None or
- compute_capability_from_device_desc(local_device.physical_device_desc)
- >= min_cuda_compute_capability):
- return True
- if local_device.device_type == 'SYCL' and not cuda_only:
- return True
- return False
-
-
_allowed_symbols = [
# We piggy-back googletest documentation.
'Benchmark',