aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py96
1 files changed, 58 insertions, 38 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index c09e2d8084..1560766fc9 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -691,7 +691,7 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
- def _AssertProtoEquals(self, a, b):
+ def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
Uses ProtoEq() first, as it returns correct results
@@ -701,11 +701,12 @@ class TensorFlowTestCase(googletest.TestCase):
Args:
a: a proto.
b: another proto.
+ msg: Optional message to report on failure.
"""
if not compare.ProtoEq(a, b):
- compare.assertProtoEqual(self, a, b, normalize_numbers=True)
+ compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
- def assertProtoEquals(self, expected_message_maybe_ascii, message):
+ def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
"""Asserts that message is same as parsed expected_message_ascii.
Creates another prototype of message, reads the ascii message into it and
@@ -714,8 +715,9 @@ class TensorFlowTestCase(googletest.TestCase):
Args:
expected_message_maybe_ascii: proto message in original or ascii form.
message: the message to validate.
+ msg: Optional message to report on failure.
"""
-
+ msg = msg if msg else ""
if isinstance(expected_message_maybe_ascii, type(message)):
expected_message = expected_message_maybe_ascii
self._AssertProtoEquals(expected_message, message)
@@ -725,20 +727,21 @@ class TensorFlowTestCase(googletest.TestCase):
expected_message_maybe_ascii,
expected_message,
descriptor_pool=descriptor_pool.Default())
- self._AssertProtoEquals(expected_message, message)
+ self._AssertProtoEquals(expected_message, message, msg=msg)
else:
- assert False, ("Can't compare protos of type %s and %s" %
- (type(expected_message_maybe_ascii), type(message)))
+ assert False, ("Can't compare protos of type %s and %s. %s" %
+ (type(expected_message_maybe_ascii), type(message), msg))
def assertProtoEqualsVersion(
self,
expected,
actual,
producer=versions.GRAPH_DEF_VERSION,
- min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER):
+ min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
+ msg=None):
expected = "versions { producer: %d min_consumer: %d };\n%s" % (
producer, min_consumer, expected)
- self.assertProtoEquals(expected, actual)
+ self.assertProtoEquals(expected, actual, msg=msg)
def assertStartsWith(self, actual, expected_start, msg=None):
"""Assert that actual.startswith(expected_start) is True.
@@ -1028,7 +1031,7 @@ class TensorFlowTestCase(googletest.TestCase):
"%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
if msg is not None else ""))
- def assertArrayNear(self, farray1, farray2, err):
+ def assertArrayNear(self, farray1, farray2, err, msg=None):
"""Asserts that two float arrays are near each other.
Checks that for all elements of farray1 and farray2
@@ -1038,23 +1041,25 @@ class TensorFlowTestCase(googletest.TestCase):
farray1: a list of float values.
farray2: a list of float values.
err: a float value.
+ msg: Optional message to report on failure.
"""
- self.assertEqual(len(farray1), len(farray2))
+ self.assertEqual(len(farray1), len(farray2), msg=msg)
for f1, f2 in zip(farray1, farray2):
- self.assertNear(float(f1), float(f2), err)
+ self.assertNear(float(f1), float(f2), err, msg=msg)
def _NDArrayNear(self, ndarray1, ndarray2, err):
return np.linalg.norm(ndarray1 - ndarray2) < err
- def assertNDArrayNear(self, ndarray1, ndarray2, err):
+ def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
"""Asserts that two numpy arrays have near values.
Args:
ndarray1: a numpy ndarray.
ndarray2: a numpy ndarray.
err: a float. The maximum absolute difference allowed.
+ msg: Optional message to report on failure.
"""
- self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err))
+ self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
def _GetNdArray(self, a):
if not isinstance(a, np.ndarray):
@@ -1096,9 +1101,16 @@ class TensorFlowTestCase(googletest.TestCase):
np.testing.assert_allclose(
a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
- def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None):
+ def _assertAllCloseRecursive(self,
+ a,
+ b,
+ rtol=1e-6,
+ atol=1e-6,
+ path=None,
+ msg=None):
path = path or []
path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
+ msg = msg if msg else ""
# Check if a and/or b are namedtuples.
if hasattr(a, "_asdict"):
@@ -1107,18 +1119,18 @@ class TensorFlowTestCase(googletest.TestCase):
b = b._asdict()
a_is_dict = isinstance(a, dict)
if a_is_dict != isinstance(b, dict):
- raise ValueError("Can't compare dict to non-dict, a%s vs b%s." %
- (path_str, path_str))
+ raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
+ (path_str, path_str, msg))
if a_is_dict:
self.assertItemsEqual(
a.keys(),
b.keys(),
- msg="mismatched keys: a%s has keys %s, but b%s has keys %s" %
- (path_str, a.keys(), path_str, b.keys()))
+ msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
+ (path_str, a.keys(), path_str, b.keys(), msg))
for k in a:
path.append(k)
self._assertAllCloseRecursive(
- a[k], b[k], rtol=rtol, atol=atol, path=path)
+ a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
del path[-1]
elif isinstance(a, (list, tuple)):
# Try to directly compare a, b as ndarrays; if not work, then traverse
@@ -1131,17 +1143,17 @@ class TensorFlowTestCase(googletest.TestCase):
b_as_ndarray,
rtol=rtol,
atol=atol,
- msg="Mismatched value: a%s is different from b%s." % (path_str,
- path_str))
+ msg="Mismatched value: a%s is different from b%s. %s" %
+ (path_str, path_str, msg))
except (ValueError, TypeError) as e:
if len(a) != len(b):
raise ValueError(
- "Mismatched length: a%s has %d items, but b%s has %d items" %
- (path_str, len(a), path_str, len(b)))
+ "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
+ (path_str, len(a), path_str, len(b), msg))
for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
path.append(str(idx))
self._assertAllCloseRecursive(
- a_ele, b_ele, rtol=rtol, atol=atol, path=path)
+ a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
del path[-1]
# a and b are ndarray like objects
else:
@@ -1159,7 +1171,7 @@ class TensorFlowTestCase(googletest.TestCase):
e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:])
raise
- def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
+ def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
"""Asserts that two structures of numpy arrays, have near values.
`a` and `b` can be arbitrarily nested structures. A layer of a nested
@@ -1172,6 +1184,7 @@ class TensorFlowTestCase(googletest.TestCase):
numpy `ndarray`, or any arbitrarily nested of structure of these.
rtol: relative tolerance.
atol: absolute tolerance.
+ msg: Optional message to report on failure.
Raises:
ValueError: if only one of `a[p]` and `b[p]` is a dict or
@@ -1179,7 +1192,7 @@ class TensorFlowTestCase(googletest.TestCase):
to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
`[p] = [1]['d']`, then `a[p] = (6, 7)`.
"""
- self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol)
+ self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
def assertAllCloseAccordingToType(self,
a,
@@ -1191,7 +1204,8 @@ class TensorFlowTestCase(googletest.TestCase):
half_rtol=1e-3,
half_atol=1e-3,
bfloat16_rtol=1e-2,
- bfloat16_atol=1e-2):
+ bfloat16_atol=1e-2,
+ msg=None):
"""Like assertAllClose, but also suitable for comparing fp16 arrays.
In particular, the tolerance is reduced to 1e-3 if at least
@@ -1208,6 +1222,7 @@ class TensorFlowTestCase(googletest.TestCase):
half_atol: absolute tolerance for float16.
bfloat16_rtol: relative tolerance for bfloat16.
bfloat16_atol: absolute tolerance for bfloat16.
+ msg: Optional message to report on failure.
"""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
@@ -1224,19 +1239,21 @@ class TensorFlowTestCase(googletest.TestCase):
rtol = max(rtol, bfloat16_rtol)
atol = max(atol, bfloat16_atol)
- self.assertAllClose(a, b, rtol=rtol, atol=atol)
+ self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
- def assertAllEqual(self, a, b):
+ def assertAllEqual(self, a, b, msg=None):
"""Asserts that two numpy arrays have the same values.
Args:
a: the expected numpy ndarray or anything can be converted to one.
b: the actual numpy ndarray or anything can be converted to one.
+ msg: Optional message to report on failure.
"""
+ msg = msg if msg else ""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
- (a.shape, b.shape))
+ self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
+ " %s" % (a.shape, b.shape, msg))
same = (a == b)
if a.dtype == np.float32 or a.dtype == np.float64:
@@ -1253,7 +1270,7 @@ class TensorFlowTestCase(googletest.TestCase):
x, y = a, b
print("not equal lhs = ", x)
print("not equal rhs = ", y)
- np.testing.assert_array_equal(a, b)
+ np.testing.assert_array_equal(a, b, err_msg=msg)
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
@@ -1303,12 +1320,13 @@ class TensorFlowTestCase(googletest.TestCase):
return self.assertRaisesWithPredicateMatch(errors.OpError,
expected_err_re_or_predicate)
- def assertShapeEqual(self, np_array, tf_tensor):
+ def assertShapeEqual(self, np_array, tf_tensor, msg=None):
"""Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
Args:
np_array: A Numpy ndarray or Numpy scalar.
tf_tensor: A Tensor.
+ msg: Optional message to report on failure.
Raises:
TypeError: If the arguments have the wrong type.
@@ -1317,19 +1335,21 @@ class TensorFlowTestCase(googletest.TestCase):
raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
if not isinstance(tf_tensor, ops.Tensor):
raise TypeError("tf_tensor must be a Tensor")
- self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list())
+ self.assertAllEqual(
+ np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
- def assertDeviceEqual(self, device1, device2):
+ def assertDeviceEqual(self, device1, device2, msg=None):
"""Asserts that the two given devices are the same.
Args:
device1: A string device name or TensorFlow `DeviceSpec` object.
device2: A string device name or TensorFlow `DeviceSpec` object.
+ msg: Optional message to report on failure.
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
- self.assertEqual(device1, device2,
- "Devices %s and %s are not equal" % (device1, device2))
+ self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
+ (device1, device2, msg))
# Fix Python 3 compatibility issues
if six.PY3: