diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 96 |
1 files changed, 58 insertions, 38 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c09e2d8084..1560766fc9 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -691,7 +691,7 @@ class TensorFlowTestCase(googletest.TestCase): self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir()) return self._tempdir - def _AssertProtoEquals(self, a, b): + def _AssertProtoEquals(self, a, b, msg=None): """Asserts that a and b are the same proto. Uses ProtoEq() first, as it returns correct results @@ -701,11 +701,12 @@ class TensorFlowTestCase(googletest.TestCase): Args: a: a proto. b: another proto. + msg: Optional message to report on failure. """ if not compare.ProtoEq(a, b): - compare.assertProtoEqual(self, a, b, normalize_numbers=True) + compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) - def assertProtoEquals(self, expected_message_maybe_ascii, message): + def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): """Asserts that message is same as parsed expected_message_ascii. Creates another prototype of message, reads the ascii message into it and @@ -714,8 +715,9 @@ class TensorFlowTestCase(googletest.TestCase): Args: expected_message_maybe_ascii: proto message in original or ascii form. message: the message to validate. + msg: Optional message to report on failure. """ - + msg = msg if msg else "" if isinstance(expected_message_maybe_ascii, type(message)): expected_message = expected_message_maybe_ascii self._AssertProtoEquals(expected_message, message) @@ -725,20 +727,21 @@ class TensorFlowTestCase(googletest.TestCase): expected_message_maybe_ascii, expected_message, descriptor_pool=descriptor_pool.Default()) - self._AssertProtoEquals(expected_message, message) + self._AssertProtoEquals(expected_message, message, msg=msg) else: - assert False, ("Can't compare protos of type %s and %s" % - (type(expected_message_maybe_ascii), type(message))) + assert False, ("Can't compare protos of type %s and %s. %s" % + (type(expected_message_maybe_ascii), type(message), msg)) def assertProtoEqualsVersion( self, expected, actual, producer=versions.GRAPH_DEF_VERSION, - min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER): + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, + msg=None): expected = "versions { producer: %d min_consumer: %d };\n%s" % ( producer, min_consumer, expected) - self.assertProtoEquals(expected, actual) + self.assertProtoEquals(expected, actual, msg=msg) def assertStartsWith(self, actual, expected_start, msg=None): """Assert that actual.startswith(expected_start) is True. @@ -1028,7 +1031,7 @@ class TensorFlowTestCase(googletest.TestCase): "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg if msg is not None else "")) - def assertArrayNear(self, farray1, farray2, err): + def assertArrayNear(self, farray1, farray2, err, msg=None): """Asserts that two float arrays are near each other. Checks that for all elements of farray1 and farray2 @@ -1038,23 +1041,25 @@ class TensorFlowTestCase(googletest.TestCase): farray1: a list of float values. farray2: a list of float values. err: a float value. + msg: Optional message to report on failure. """ - self.assertEqual(len(farray1), len(farray2)) + self.assertEqual(len(farray1), len(farray2), msg=msg) for f1, f2 in zip(farray1, farray2): - self.assertNear(float(f1), float(f2), err) + self.assertNear(float(f1), float(f2), err, msg=msg) def _NDArrayNear(self, ndarray1, ndarray2, err): return np.linalg.norm(ndarray1 - ndarray2) < err - def assertNDArrayNear(self, ndarray1, ndarray2, err): + def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): """Asserts that two numpy arrays have near values. Args: ndarray1: a numpy ndarray. ndarray2: a numpy ndarray. err: a float. The maximum absolute difference allowed. + msg: Optional message to report on failure. """ - self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err)) + self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) def _GetNdArray(self, a): if not isinstance(a, np.ndarray): @@ -1096,9 +1101,16 @@ class TensorFlowTestCase(googletest.TestCase): np.testing.assert_allclose( a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) - def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None): + def _assertAllCloseRecursive(self, + a, + b, + rtol=1e-6, + atol=1e-6, + path=None, + msg=None): path = path or [] path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "") + msg = msg if msg else "" # Check if a and/or b are namedtuples. if hasattr(a, "_asdict"): @@ -1107,18 +1119,18 @@ class TensorFlowTestCase(googletest.TestCase): b = b._asdict() a_is_dict = isinstance(a, dict) if a_is_dict != isinstance(b, dict): - raise ValueError("Can't compare dict to non-dict, a%s vs b%s." % - (path_str, path_str)) + raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % + (path_str, path_str, msg)) if a_is_dict: self.assertItemsEqual( a.keys(), b.keys(), - msg="mismatched keys: a%s has keys %s, but b%s has keys %s" % - (path_str, a.keys(), path_str, b.keys())) + msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % + (path_str, a.keys(), path_str, b.keys(), msg)) for k in a: path.append(k) self._assertAllCloseRecursive( - a[k], b[k], rtol=rtol, atol=atol, path=path) + a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) del path[-1] elif isinstance(a, (list, tuple)): # Try to directly compare a, b as ndarrays; if not work, then traverse @@ -1131,17 +1143,17 @@ class TensorFlowTestCase(googletest.TestCase): b_as_ndarray, rtol=rtol, atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + msg="Mismatched value: a%s is different from b%s. %s" % + (path_str, path_str, msg)) except (ValueError, TypeError) as e: if len(a) != len(b): raise ValueError( - "Mismatched length: a%s has %d items, but b%s has %d items" % - (path_str, len(a), path_str, len(b))) + "Mismatched length: a%s has %d items, but b%s has %d items. %s" % + (path_str, len(a), path_str, len(b), msg)) for idx, (a_ele, b_ele) in enumerate(zip(a, b)): path.append(str(idx)) self._assertAllCloseRecursive( - a_ele, b_ele, rtol=rtol, atol=atol, path=path) + a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) del path[-1] # a and b are ndarray like objects else: @@ -1159,7 +1171,7 @@ class TensorFlowTestCase(googletest.TestCase): e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) raise - def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): + def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): """Asserts that two structures of numpy arrays, have near values. `a` and `b` can be arbitrarily nested structures. A layer of a nested @@ -1172,6 +1184,7 @@ class TensorFlowTestCase(googletest.TestCase): numpy `ndarray`, or any arbitrarily nested of structure of these. rtol: relative tolerance. atol: absolute tolerance. + msg: Optional message to report on failure. Raises: ValueError: if only one of `a[p]` and `b[p]` is a dict or @@ -1179,7 +1192,7 @@ class TensorFlowTestCase(googletest.TestCase): to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and `[p] = [1]['d']`, then `a[p] = (6, 7)`. """ - self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol) + self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) def assertAllCloseAccordingToType(self, a, @@ -1191,7 +1204,8 @@ class TensorFlowTestCase(googletest.TestCase): half_rtol=1e-3, half_atol=1e-3, bfloat16_rtol=1e-2, - bfloat16_atol=1e-2): + bfloat16_atol=1e-2, + msg=None): """Like assertAllClose, but also suitable for comparing fp16 arrays. In particular, the tolerance is reduced to 1e-3 if at least @@ -1208,6 +1222,7 @@ class TensorFlowTestCase(googletest.TestCase): half_atol: absolute tolerance for float16. bfloat16_rtol: relative tolerance for bfloat16. bfloat16_atol: absolute tolerance for bfloat16. + msg: Optional message to report on failure. """ a = self._GetNdArray(a) b = self._GetNdArray(b) @@ -1224,19 +1239,21 @@ class TensorFlowTestCase(googletest.TestCase): rtol = max(rtol, bfloat16_rtol) atol = max(atol, bfloat16_atol) - self.assertAllClose(a, b, rtol=rtol, atol=atol) + self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) - def assertAllEqual(self, a, b): + def assertAllEqual(self, a, b, msg=None): """Asserts that two numpy arrays have the same values. Args: a: the expected numpy ndarray or anything can be converted to one. b: the actual numpy ndarray or anything can be converted to one. + msg: Optional message to report on failure. """ + msg = msg if msg else "" a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." % - (a.shape, b.shape)) + self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." + " %s" % (a.shape, b.shape, msg)) same = (a == b) if a.dtype == np.float32 or a.dtype == np.float64: @@ -1253,7 +1270,7 @@ class TensorFlowTestCase(googletest.TestCase): x, y = a, b print("not equal lhs = ", x) print("not equal rhs = ", y) - np.testing.assert_array_equal(a, b) + np.testing.assert_array_equal(a, b, err_msg=msg) # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager @@ -1303,12 +1320,13 @@ class TensorFlowTestCase(googletest.TestCase): return self.assertRaisesWithPredicateMatch(errors.OpError, expected_err_re_or_predicate) - def assertShapeEqual(self, np_array, tf_tensor): + def assertShapeEqual(self, np_array, tf_tensor, msg=None): """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. Args: np_array: A Numpy ndarray or Numpy scalar. tf_tensor: A Tensor. + msg: Optional message to report on failure. Raises: TypeError: If the arguments have the wrong type. @@ -1317,19 +1335,21 @@ class TensorFlowTestCase(googletest.TestCase): raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") if not isinstance(tf_tensor, ops.Tensor): raise TypeError("tf_tensor must be a Tensor") - self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list()) + self.assertAllEqual( + np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) - def assertDeviceEqual(self, device1, device2): + def assertDeviceEqual(self, device1, device2, msg=None): """Asserts that the two given devices are the same. Args: device1: A string device name or TensorFlow `DeviceSpec` object. device2: A string device name or TensorFlow `DeviceSpec` object. + msg: Optional message to report on failure. """ device1 = pydev.canonical_name(device1) device2 = pydev.canonical_name(device2) - self.assertEqual(device1, device2, - "Devices %s and %s are not equal" % (device1, device2)) + self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" % + (device1, device2, msg)) # Fix Python 3 compatibility issues if six.PY3: |