From 0e2089c775ce1e19cc9429e496d84952e11c907c Mon Sep 17 00:00:00 2001 From: cyyber Date: Fri, 15 Dec 2017 01:01:31 +0530 Subject: Calling Keychecker before checking key in MessageMap --- python/google/protobuf/internal/containers.py | 4 +++- python/google/protobuf/internal/message_test.py | 16 ++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) (limited to 'python') diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 68be9e54..c6a3692a 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -549,10 +549,10 @@ class MessageMap(MutableMapping): self._values = {} def __getitem__(self, key): + key = self._key_checker.CheckValue(key) try: return self._values[key] except KeyError: - key = self._key_checker.CheckValue(key) new_element = self._message_descriptor._concrete_class() new_element._SetListener(self._message_listener) self._values[key] = new_element @@ -584,12 +584,14 @@ class MessageMap(MutableMapping): return default def __contains__(self, item): + item = self._key_checker.CheckValue(item) return item in self._values def __setitem__(self, key, value): raise ValueError('May not set values directly, call my_map[key].foo = 5') def __delitem__(self, key): + key = self._key_checker.CheckValue(key) del self._values[key] self._message_listener.Modified() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index a303b1aa..8dae6377 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1480,12 +1480,8 @@ class Proto3Test(BaseTestCase): submsg = msg.map_int32_foreign_message[5] self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) - # TODO(jieluo): Fix python and cpp extension diff. - if api_implementation.Type() == 'cpp': - with self.assertRaises(TypeError): - msg.map_int32_foreign_message.get('') - else: - self.assertEqual(None, msg.map_int32_foreign_message.get('')) + with self.assertRaises(TypeError): + msg.map_int32_foreign_message.get('') def testScalarMap(self): msg = map_unittest_pb2.TestMap() @@ -1695,12 +1691,8 @@ class Proto3Test(BaseTestCase): del msg2.map_int32_foreign_message[222] self.assertFalse(222 in msg2.map_int32_foreign_message) - if api_implementation.Type() == 'cpp': - with self.assertRaises(TypeError): - del msg2.map_int32_foreign_message[''] - else: - with self.assertRaises(KeyError): - del msg2.map_int32_foreign_message[''] + with self.assertRaises(TypeError): + del msg2.map_int32_foreign_message[''] def testMergeFromBadType(self): msg = map_unittest_pb2.TestMap() -- cgit v1.2.3