aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-10-05 16:47:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 16:52:24 -0700
commit29af23aeadd1d6fccbfa4223b58dad8f5b8df4f8 (patch)
treea3648b6f06e0cc2d672bdc4084b804d9337c1e4e /tensorflow/python/util
parent1daaf0fabee1c59af00e14f358d08ac9f5390b9f (diff)
Fix api_compatibility_test diff for large files. assertEqual might be applied
instead of assertMultiLineEqual if input is too large (https://bugs.python.org/issue11763). This change is switching to use unified_diff in that case. PiperOrigin-RevId: 215987656
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/protobuf/compare.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index a0e6bf65cf..3a3af4bffa 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -63,6 +63,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import difflib
import six
@@ -101,10 +102,19 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
if normalize_numbers:
NormalizeNumberFields(pb)
- self.assertMultiLineEqual(
- text_format.MessageToString(a, descriptor_pool=pool),
- text_format.MessageToString(b, descriptor_pool=pool),
- msg=msg)
+ a_str = text_format.MessageToString(a, descriptor_pool=pool)
+ b_str = text_format.MessageToString(b, descriptor_pool=pool)
+
+ # Some Python versions would perform regular diff instead of multi-line
+ # diff if string is longer than 2**16. We substitute this behavior
+ # with a call to unified_diff instead to have easier-to-read diffs.
+ # For context, see: https://bugs.python.org/issue11763.
+ if len(a_str) < 2**16 and len(b_str) < 2**16:
+ self.assertMultiLineEqual(a_str, b_str, msg=msg)
+ else:
+ diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True),
+ b_str.splitlines(True)))
+ self.fail('%s : %s' % (msg, diff))
def NormalizeNumberFields(pb):