aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/protobuf/compare_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/util/protobuf/compare_test.py')
-rw-r--r--tensorflow/python/util/protobuf/compare_test.py652
1 files changed, 652 insertions, 0 deletions
diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py
new file mode 100644
index 0000000000..9a03d123ae
--- /dev/null
+++ b/tensorflow/python/util/protobuf/compare_test.py
@@ -0,0 +1,652 @@
+#!/usr/bin/python2.4
+
+"""Tests for python.util.protobuf.compare."""
+
+import copy
+import re
+import textwrap
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.util.protobuf import compare
+from tensorflow.python.util.protobuf import compare_test_pb2
+
+from google.protobuf import text_format
+
+
+def LargePbs(*args):
+ """Converts ASCII string Large PBs to messages."""
+ pbs = []
+ for arg in args:
+ pb = compare_test_pb2.Large()
+ text_format.Merge(arg, pb)
+ pbs.append(pb)
+
+ return pbs
+
+
+class Proto2CmpTest(googletest.TestCase):
+
+ def assertGreater(self, a, b):
+ """Asserts that Proto2Cmp says a > b."""
+ a, b = LargePbs(a, b)
+ googletest.TestCase.assertGreater(self, compare.Proto2Cmp(a, b), 0)
+ googletest.TestCase.assertLess(self, compare.Proto2Cmp(b, a), 0)
+
+ def assertEquals(self, a, b):
+ """Asserts that Proto2Cmp says a == b."""
+ a, b = LargePbs(a, b)
+ googletest.TestCase.assertEquals(self, compare.Proto2Cmp(a, b), 0)
+
+ def testPrimitives(self):
+ googletest.TestCase.assertEqual(self, 0, compare.Proto2Cmp('a', 'a'))
+ googletest.TestCase.assertLess(self, 0, compare.Proto2Cmp('b', 'a'))
+
+ pb = compare_test_pb2.Large()
+ googletest.TestCase.assertEquals(self, cmp('a', pb), compare.Proto2Cmp('a', pb))
+ googletest.TestCase.assertEqual(self, cmp(pb, 'a'), compare.Proto2Cmp(pb, 'a'))
+
+ def testEmpty(self):
+ self.assertEquals('', '')
+
+ def testPrimitiveFields(self):
+ self.assertGreater('string_: "a"', '')
+ self.assertEquals('string_: "a"', 'string_: "a"')
+ self.assertGreater('string_: "b"', 'string_: "a"')
+ self.assertGreater('string_: "ab"', 'string_: "aa"')
+
+ self.assertGreater('int64_: 0', '')
+ self.assertEquals('int64_: 0', 'int64_: 0')
+ self.assertGreater('int64_: -1', '')
+ self.assertGreater('int64_: 1', 'int64_: 0')
+ self.assertGreater('int64_: 0', 'int64_: -1')
+
+ self.assertGreater('float_: 0.0', '')
+ self.assertEquals('float_: 0.0', 'float_: 0.0')
+ self.assertGreater('float_: -0.1', '')
+ self.assertGreater('float_: 3.14', 'float_: 0')
+ self.assertGreater('float_: 0', 'float_: -0.1')
+ self.assertEquals('float_: -0.1', 'float_: -0.1')
+
+ self.assertGreater('bool_: true', '')
+ self.assertGreater('bool_: false', '')
+ self.assertGreater('bool_: true', 'bool_: false')
+ self.assertEquals('bool_: false', 'bool_: false')
+ self.assertEquals('bool_: true', 'bool_: true')
+
+ self.assertGreater('enum_: A', '')
+ self.assertGreater('enum_: B', 'enum_: A')
+ self.assertGreater('enum_: C', 'enum_: B')
+ self.assertEquals('enum_: C', 'enum_: C')
+
+ def testRepeatedPrimitives(self):
+ self.assertGreater('int64s: 0', '')
+ self.assertEquals('int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 1', 'int64s: 0')
+ self.assertGreater('int64s: 0 int64s: 0', '')
+ self.assertGreater('int64s: 0 int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 0 int64s: 1', 'int64s: 0')
+ self.assertGreater('int64s: 1', 'int64s: 0 int64s: 2')
+ self.assertGreater('int64s: 2 int64s: 0', 'int64s: 1')
+ self.assertEquals('int64s: 0 int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertEquals('int64s: 0 int64s: 1', 'int64s: 0 int64s: 1')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 1')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 2')
+ self.assertGreater('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0 int64s: 2')
+
+ def testMessage(self):
+ self.assertGreater('small <>', '')
+ self.assertEquals('small <>', 'small <>')
+ self.assertGreater('small < strings: "a" >', '')
+ self.assertGreater('small < strings: "a" >', 'small <>')
+ self.assertEquals('small < strings: "a" >', 'small < strings: "a" >')
+ self.assertGreater('small < strings: "b" >', 'small < strings: "a" >')
+ self.assertGreater('small < strings: "a" strings: "b" >',
+ 'small < strings: "a" >')
+
+ self.assertGreater('string_: "a"', 'small <>')
+ self.assertGreater('string_: "a"', 'small < strings: "b" >')
+ self.assertGreater('string_: "a"', 'small < strings: "b" strings: "c" >')
+ self.assertGreater('string_: "a" small <>', 'small <>')
+ self.assertGreater('string_: "a" small <>', 'small < strings: "b" >')
+ self.assertEquals('string_: "a" small <>', 'string_: "a" small <>')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'string_: "a" small <>')
+ self.assertEquals('string_: "a" small < strings: "a" >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >', 'int64_: 1')
+ self.assertGreater('string_: "a"', 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" int64_: 0 small < strings: "a" >',
+ 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" int64_: 1 small < strings: "a" >',
+ 'string_: "a" int64_: 0 small < strings: "a" >')
+ self.assertEquals('string_: "a" int64_: 0 small < strings: "a" >',
+ 'string_: "a" int64_: 0 small < strings: "a" >')
+
+ def testNestedMessage(self):
+ self.assertGreater('medium <>', '')
+ self.assertEquals('medium <>', 'medium <>')
+ self.assertGreater('medium < smalls <> >', 'medium <>')
+ self.assertEquals('medium < smalls <> >', 'medium < smalls <> >')
+ self.assertGreater('medium < smalls <> smalls <> >', 'medium < smalls <> >')
+ self.assertEquals('medium < smalls <> smalls <> >',
+ 'medium < smalls <> smalls <> >')
+
+ self.assertGreater('medium < int32s: 0 >', 'medium < smalls <> >')
+
+ self.assertGreater('medium < smalls < strings: "a"> >',
+ 'medium < smalls <> >')
+
+ def testTagOrder(self):
+ """Tests that different fields are ordered by tag number.
+
+ For reference, here are the relevant tag numbers from compare_test.proto:
+ optional string string_ = 1;
+ optional int64 int64_ = 2;
+ optional float float_ = 3;
+ optional Small small = 8;
+ optional Medium medium = 7;
+ optional Small small = 8;
+ """
+ self.assertGreater('string_: "a" ',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "a" int64_: 2 ',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "b" int64_: 1 ',
+ 'string_: "a" int64_: 2 ')
+ self.assertEquals( 'string_: "a" int64_: 1 ',
+ 'string_: "a" int64_: 1 ')
+ self.assertGreater('string_: "a" int64_: 1 float_: 0.0',
+ 'string_: "a" int64_: 1 ')
+ self.assertEquals( 'string_: "a" int64_: 1 float_: 0.0',
+ 'string_: "a" int64_: 1 float_: 0.0')
+ self.assertGreater('string_: "a" int64_: 1 float_: 0.1',
+ 'string_: "a" int64_: 1 float_: 0.0')
+ self.assertGreater('string_: "a" int64_: 2 float_: 0.0',
+ 'string_: "a" int64_: 1 float_: 0.1')
+ self.assertGreater('string_: "a" ',
+ ' int64_: 1 float_: 0.1')
+ self.assertGreater('string_: "a" float_: 0.0',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "b" float_: 0.0',
+ 'string_: "a" int64_: 1 ')
+
+ self.assertGreater('string_: "a"',
+ 'small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'small < strings: "b" >')
+ self.assertGreater('string_: "a" small < strings: "b" >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertEquals('string_: "a" small < strings: "a" >',
+ 'string_: "a" small < strings: "a" >')
+
+ self.assertGreater('string_: "a" medium <>',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('string_: "a" medium < smalls <> >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('medium <>', 'small < strings: "a" >')
+ self.assertGreater('medium <> small <>', 'small < strings: "a" >')
+ self.assertGreater('medium < smalls <> >', 'small < strings: "a" >')
+ self.assertGreater('medium < smalls < strings: "a" > >',
+ 'small < strings: "b" >')
+
+
+class NormalizeRepeatedFieldsTest(googletest.TestCase):
+
+ def assertNormalizes(self, orig, expected_no_dedupe, expected_dedupe):
+ """Checks NormalizeRepeatedFields(orig) against the two expected results."""
+ orig, expected_no_dedupe, expected_dedupe = LargePbs(
+ orig, expected_no_dedupe, expected_dedupe)
+
+ actual = compare.NormalizeRepeatedFields(copy.deepcopy(orig), dedupe=False)
+ self.assertEqual(expected_no_dedupe, actual)
+
+ actual = compare.NormalizeRepeatedFields(copy.deepcopy(orig), dedupe=True)
+ self.assertEqual(expected_dedupe, actual)
+
+ def testIgnoreNonRepeatedFields(self):
+ orig = """string_: "a" int64_: 1 float_: 0.1 bool_: true enum_: A
+ medium: {} small: {}"""
+ self.assertNormalizes(orig, orig, orig)
+
+ def testRepeatedPrimitive(self):
+ self.assertNormalizes('int64s: 3 int64s: -1 int64s: 2 int64s: -1 int64s: 3',
+ 'int64s: -1 int64s: -1 int64s: 2 int64s: 3 int64s: 3',
+ 'int64s: -1 int64s: 2 int64s: 3')
+
+ def testRepeatedMessage(self):
+ self.assertNormalizes("""medium: { smalls: { strings: "c" }
+ smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "a" }
+ smalls: { strings: "c" } }
+ """,
+ """medium: { smalls: { strings: "a" }
+ smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "c" }
+ smalls: { strings: "c" } }
+ """,
+ """medium: { smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "c" } }
+ """)
+
+ def testNestedRepeatedGroup(self):
+ self.assertNormalizes("""medium { GroupA { GroupB { strings: "c" }
+ GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "a" }
+ GroupB { strings: "c" } } }
+ """,
+ """medium { GroupA { GroupB { strings: "a" }
+ GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "c" }
+ GroupB { strings: "c" } } }
+ """,
+ """medium { GroupA { GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "c" } } }
+ """)
+
+ def testMapNormalizes(self):
+ self.assertNormalizes(
+ """with_map: { value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v2",
+ strings: "k2v1" } },
+ value_message: { key: 1, value: { strings: "k1v2",
+ strings: "k1v1" } } }
+ """,
+ """with_map: { value_message: { key: 1, value: { strings: "k1v1",
+ strings: "k1v2" } },
+ value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v1",
+ strings: "k2v2" } } }
+ """,
+ """with_map: { value_message: { key: 1, value: { strings: "k1v1",
+ strings: "k1v2" } },
+ value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v2" } } }
+ """)
+
+
+class NormalizeNumbersTest(googletest.TestCase):
+ """Tests for NormalizeNumberFields()."""
+
+ def testNormalizesInts(self):
+ pb = compare_test_pb2.Large()
+ pb.int64_ = 4
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ pb.int64_ = 4L
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ pb.int64_ = 9999999999999999L
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ def testNormalizesRepeatedInts(self):
+ pb = compare_test_pb2.Large()
+ pb.int64s.extend([1L, 400, 999999999999999L])
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64s[0], long))
+ self.assertTrue(isinstance(pb.int64s[1], long))
+ self.assertTrue(isinstance(pb.int64s[2], long))
+
+ def testNormalizesFloats(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.float_ = 1.2314352351231
+ pb2 = compare_test_pb2.Large()
+ pb2.float_ = 1.231435
+ self.assertNotEqual(pb1.float_, pb2.float_)
+ compare.NormalizeNumberFields(pb1)
+ compare.NormalizeNumberFields(pb2)
+ self.assertEqual(pb1.float_, pb2.float_)
+
+ def testNormalizesRepeatedFloats(self):
+ pb = compare_test_pb2.Large()
+ pb.medium.floats.extend([0.111111111, 0.111111])
+ compare.NormalizeNumberFields(pb)
+ for value in pb.medium.floats:
+ self.assertAlmostEqual(0.111111, value)
+
+ def testNormalizesDoubles(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.double_ = 1.2314352351231
+ pb2 = compare_test_pb2.Large()
+ pb2.double_ = 1.2314352
+ self.assertNotEqual(pb1.double_, pb2.double_)
+ compare.NormalizeNumberFields(pb1)
+ compare.NormalizeNumberFields(pb2)
+ self.assertEqual(pb1.double_, pb2.double_)
+
+ def testNormalizesMaps(self):
+ pb = compare_test_pb2.WithMap()
+ pb.value_message[4].strings.extend(['a', 'b', 'c'])
+ pb.value_string['d'] = 'e'
+ compare.NormalizeNumberFields(pb)
+
+
+class AssertTest(googletest.TestCase):
+ """Tests both assertProto2Equal() and assertProto2SameElements()."""
+ def assertProto2Equal(self, a, b, **kwargs):
+ if isinstance(a, basestring) and isinstance(b, basestring):
+ a, b = LargePbs(a, b)
+ compare.assertProto2Equal(self, a, b, **kwargs)
+
+ def assertProto2SameElements(self, a, b, **kwargs):
+ if isinstance(a, basestring) and isinstance(b, basestring):
+ a, b = LargePbs(a, b)
+ compare.assertProto2SameElements(self, a, b, **kwargs)
+
+ def assertAll(self, a, **kwargs):
+ """Checks that all possible asserts pass."""
+ self.assertProto2Equal(a, a, **kwargs)
+ self.assertProto2SameElements(a, a, number_matters=False, **kwargs)
+ self.assertProto2SameElements(a, a, number_matters=True, **kwargs)
+
+ def assertSameNotEqual(self, a, b):
+ """Checks that assertProto2SameElements() passes with number_matters=False
+ and number_matters=True but not assertProto2Equal().
+ """
+ self.assertProto2SameElements(a, b, number_matters=False)
+ self.assertProto2SameElements(a, b, number_matters=True)
+ self.assertRaises(AssertionError, self.assertProto2Equal, a, b)
+
+ def assertSameExceptNumber(self, a, b):
+ """Checks that assertProto2SameElements() passes with number_matters=False
+ but not number_matters=True or assertProto2Equal().
+ """
+ self.assertProto2SameElements(a, b, number_matters=False)
+ self.assertRaises(AssertionError, self.assertProto2SameElements, a, b,
+ number_matters=True)
+ self.assertRaises(AssertionError, self.assertProto2Equal, a, b)
+
+ def assertNone(self, a, b, message, **kwargs):
+ """Checks that all possible asserts fail with the given message."""
+ message = re.escape(textwrap.dedent(message))
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2SameElements, a, b,
+ number_matters=False, **kwargs)
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2SameElements, a, b,
+ number_matters=True, **kwargs)
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2Equal, a, b, **kwargs)
+
+ def testCheckInitialized(self):
+ # neither is initialized
+ a = compare_test_pb2.Labeled()
+ a.optional = 1
+ self.assertNone(a, a, 'Initialization errors: ', check_initialized=True)
+ self.assertAll(a, check_initialized=False)
+
+ # a is initialized, b isn't
+ b = copy.deepcopy(a)
+ a.required = 2
+ self.assertNone(a, b, 'Initialization errors: ', check_initialized=True)
+ self.assertNone(a, b,
+ """
+ - required: 2
+ optional: 1
+ """,
+ check_initialized=False)
+
+ # both are initialized
+ a = compare_test_pb2.Labeled()
+ a.required = 2
+ self.assertAll(a, check_initialized=True)
+ self.assertAll(a, check_initialized=False)
+
+ b = copy.deepcopy(a)
+ b.required = 3
+ message = """
+ - required: 2
+ ? ^
+ + required: 3
+ ? ^
+ """
+ self.assertNone(a, b, message, check_initialized=True)
+ self.assertNone(a, b, message, check_initialized=False)
+
+ def testAssertEqualWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ compare.assertProto2Equal(
+ self,
+ """
+ string_: 'abc'
+ float_: 1.234
+ """,
+ pb)
+
+ def testAssertSameElementsWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.int64s.extend([7, 3, 5])
+ compare.assertProto2SameElements(
+ self,
+ """
+ string_: 'abc'
+ float_: 1.234
+ int64s: 3
+ int64s: 7
+ int64s: 5
+ """,
+ pb)
+
+ def testProto2ContainsString(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.small.strings.append('xyz')
+ compare.assertProto2Contains(
+ self,
+ """
+ small {
+ strings: "xyz"
+ }
+ """,
+ pb)
+
+ def testProto2ContainsProto(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.small.strings.append('xyz')
+ pb2 = compare_test_pb2.Large()
+ pb2.small.strings.append('xyz')
+ compare.assertProto2Contains(
+ self, pb2, pb)
+
+ def testNormalizesNumbers(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.int64_ = 4
+ pb2 = compare_test_pb2.Large()
+ pb2.int64_ = 4L
+ compare.assertProto2Equal(self, pb1, pb2)
+
+ def testNormalizesFloat(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.double_ = 4.0
+ pb2 = compare_test_pb2.Large()
+ pb2.double_ = 4L
+ compare.assertProto2Equal(self, pb1, pb2, normalize_numbers=True)
+
+ pb1 = compare_test_pb2.Medium()
+ pb1.floats.extend([4.0, 6.0])
+ pb2 = compare_test_pb2.Medium()
+ pb2.floats.extend([6L, 4L])
+ compare.assertProto2SameElements(self, pb1, pb2, normalize_numbers=True)
+
+ def testPrimitives(self):
+ self.assertAll('string_: "x"')
+ self.assertNone('string_: "x"',
+ 'string_: "y"',
+ """
+ - string_: "x"
+ ? ^
+ + string_: "y"
+ ? ^
+ """)
+
+ def testRepeatedPrimitives(self):
+ self.assertAll('int64s: 0 int64s: 1')
+
+ self.assertSameNotEqual('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0')
+ self.assertSameNotEqual('int64s: 0 int64s: 1 int64s: 2',
+ 'int64s: 2 int64s: 1 int64s: 0')
+
+ self.assertSameExceptNumber('int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertSameExceptNumber('int64s: 0 int64s: 1',
+ 'int64s: 1 int64s: 0 int64s: 1')
+
+ self.assertNone('int64s: 0',
+ 'int64s: 0 int64s: 2',
+ """
+ int64s: 0
+ + int64s: 2
+ """)
+ self.assertNone('int64s: 0 int64s: 1',
+ 'int64s: 0 int64s: 2',
+ """
+ int64s: 0
+ - int64s: 1
+ ? ^
+ + int64s: 2
+ ? ^
+ """)
+
+ def testMessage(self):
+ self.assertAll('medium: {}')
+ self.assertAll('medium: { smalls: {} }')
+ self.assertAll('medium: { int32s: 1 smalls: {} }')
+ self.assertAll('medium: { smalls: { strings: "x" } }')
+ self.assertAll('medium: { smalls: { strings: "x" } } small: { strings: "y" }')
+
+ self.assertSameNotEqual(
+ 'medium: { smalls: { strings: "x" strings: "y" } }',
+ 'medium: { smalls: { strings: "y" strings: "x" } }')
+ self.assertSameNotEqual(
+ 'medium: { smalls: { strings: "x" } smalls: { strings: "y" } }',
+ 'medium: { smalls: { strings: "y" } smalls: { strings: "x" } }')
+
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" strings: "y" strings: "x" } }',
+ 'medium: { smalls: { strings: "y" strings: "x" } }')
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" } int32s: 0 }',
+ 'medium: { int32s: 0 smalls: { strings: "x" } int32s: 0 }')
+
+ self.assertNone('medium: {}',
+ 'medium: { smalls: { strings: "x" } }',
+ """
+ medium {
+ + smalls {
+ + strings: "x"
+ + }
+ }
+ """)
+ self.assertNone('medium: { smalls: { strings: "x" } }',
+ 'medium: { smalls: {} }',
+ """
+ medium {
+ smalls {
+ - strings: "x"
+ }
+ }
+ """)
+ self.assertNone('medium: { int32s: 0 }',
+ 'medium: { int32s: 1 }',
+ """
+ medium {
+ - int32s: 0
+ ? ^
+ + int32s: 1
+ ? ^
+ }
+ """)
+
+ def testMsgPassdown(self):
+ self.assertRaisesRegexp(AssertionError, 'test message passed down',
+ self.assertProto2Equal,
+ 'medium: {}',
+ 'medium: { smalls: { strings: "x" } }',
+ msg='test message passed down')
+
+ def testRepeatedMessage(self):
+ self.assertAll('medium: { smalls: {} smalls: {} }')
+ self.assertAll('medium: { smalls: { strings: "x" } } medium: {}')
+ self.assertAll('medium: { smalls: { strings: "x" } } medium: { int32s: 0 }')
+ self.assertAll('medium: { smalls: {} smalls: { strings: "x" } } small: {}')
+
+ self.assertSameNotEqual('medium: { smalls: { strings: "x" } smalls: {} }',
+ 'medium: { smalls: {} smalls: { strings: "x" } }')
+
+ self.assertSameExceptNumber('medium: { smalls: {} }',
+ 'medium: { smalls: {} smalls: {} }')
+ self.assertSameExceptNumber('medium: { smalls: {} smalls: {} } medium: {}',
+ 'medium: {} medium: {} medium: { smalls: {} }')
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" } smalls: {} }',
+ 'medium: { smalls: {} smalls: { strings: "x" } smalls: {} }')
+
+ self.assertNone('medium: {}',
+ 'medium: {} medium { smalls: {} }',
+ """
+ medium {
+ + smalls {
+ + }
+ }
+ """)
+ self.assertNone('medium: { smalls: {} smalls: { strings: "x" } }',
+ 'medium: { smalls: {} smalls: { strings: "y" } }',
+ """
+ medium {
+ smalls {
+ }
+ smalls {
+ - strings: "x"
+ ? ^
+ + strings: "y"
+ ? ^
+ }
+ }
+ """)
+
+
+class MixinTests(compare.Proto2Assertions, googletest.TestCase):
+
+ def testAssertEqualWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ self.assertProto2Equal(
+ """
+ string_: 'abc'
+ float_: 1.234
+ """,
+ pb)
+
+ def testAssertSameElements(self):
+ a = compare_test_pb2.Large()
+ a.string_ = 'abc'
+ a.float_ = 1.234
+ a.int64s[:] = [4, 3, 2]
+ b = compare_test_pb2.Large()
+ b.CopyFrom(a)
+ b.int64s[:] = [2, 4, 3]
+ self.assertProto2SameElements(a, b)
+
+
+if __name__ == '__main__':
+ googletest.main()