diff options
author | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
---|---|---|
committer | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
commit | bde4a3254a7de58911941b0fbf38e9dd992de973 (patch) | |
tree | 02b151c2ec6e9be2e9d5ea0efc406aabe6958ae7 /python/google/protobuf/internal | |
parent | d7339318a33c5f9e8b5dded4077223fbd4ebf229 (diff) |
down integrate python opensource to svn
Diffstat (limited to 'python/google/protobuf/internal')
34 files changed, 2823 insertions, 569 deletions
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc new file mode 100644 index 00000000..ad6fd9c8 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.cc @@ -0,0 +1,139 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +namespace google { +namespace protobuf { +namespace python { + +// Version constant. +// This is either 0 for python, 1 for CPP V1, 2 for CPP V2. +// +// 0 is default and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +// +// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1 +// +// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 +#ifdef PYTHON_PROTO2_CPP_IMPL_V1 +#if PY_MAJOR_VERSION >= 3 +#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." +#endif +static int kImplVersion = 1; +#else +#ifdef PYTHON_PROTO2_CPP_IMPL_V2 +static int kImplVersion = 2; +#else +#ifdef PYTHON_PROTO2_PYTHON_IMPL +static int kImplVersion = 0; +#else + +// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. +// Python 2 still uses the Python version by default until some compatibility +// issues can be worked around. +#if PY_MAJOR_VERSION >= 3 +static int kImplVersion = 2; +#else +static int kImplVersion = 0; +#endif + +#endif // PYTHON_PROTO2_PYTHON_IMPL +#endif // PYTHON_PROTO2_CPP_IMPL_V2 +#endif // PYTHON_PROTO2_CPP_IMPL_V1 + +static const char* kImplVersionName = "api_version"; + +static const char* kModuleName = "_api_implementation"; +static const char kModuleDocstring[] = +"_api_implementation is a module that exposes compile-time constants that\n" +"determine the default API implementation to use for Python proto2.\n" +"\n" +"It complements api_implementation.py by setting defaults using compile-time\n" +"constants defined in C, such that one can set defaults at compilation\n" +"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2)."; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + kModuleName, + kModuleDocstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__api_implementation +#define INITFUNC_ERRORVAL NULL +#else +#define INITFUNC init_api_implementation +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC() { +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&_module); +#else + PyObject *module = Py_InitModule3( + const_cast<char*>(kModuleName), + NULL, + const_cast<char*>(kModuleDocstring)); +#endif + if (module == NULL) { + return INITFUNC_ERRORVAL; + } + + // Adds the module variable "api_version". + if (PyModule_AddIntConstant( + module, + const_cast<char*>(kImplVersionName), + kImplVersion)) +#if PY_MAJOR_VERSION < 3 + return; +#else + { Py_DECREF(module); return NULL; } + + return module; +#endif + } +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ce02a329..cbb85747 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -28,41 +28,44 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Determine which implementation of the protobuf API is used in this process. """ -This module is the central entity that determines which implementation of the -API is used. -""" - -__author__ = 'petar@google.com (Petar Petrov)' import os +import sys + +try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import _api_implementation + # The compile-time constants in the _api_implementation module can be used to + # switch to a certain implementation of the Python API at build time. + _api_version = _api_implementation.api_version + del _api_implementation +except ImportError: + _api_version = 0 + +_default_implementation_type = ( + 'python' if _api_version == 0 else 'cpp') +_default_version_str = ( + '1' if _api_version <= 1 else '2') + # This environment variable can be used to switch to a certain implementation -# of the Python API. Right now only 'python' and 'cpp' are valid values. Any -# other value will be ignored. +# of the Python API, overriding the compile-time constants in the +# _api_implementation module. Right now only 'python' and 'cpp' are valid +# values. Any other value will be ignored. _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', - 'python') - + _default_implementation_type) if _implementation_type != 'python': - # For now, by default use the pure-Python implementation. - # The code below checks if the C extension is available and - # uses it if it is available. _implementation_type = 'cpp' - ## Determine automatically which implementation to use. - #try: - # from google.protobuf.internal import cpp_message - # _implementation_type = 'cpp' - #except ImportError, e: - # _implementation_type = 'python' - # This environment variable can be used to switch between the two -# 'cpp' implementations. Right now only 1 and 2 are valid values. Any -# other value will be ignored. +# 'cpp' implementations, overriding the compile-time constants in the +# _api_implementation module. Right now only 1 and 2 are valid values. Any other +# value will be ignored. _implementation_version_str = os.getenv( 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', - '1') - + _default_version_str) if _implementation_version_str not in ('1', '2'): raise ValueError( @@ -70,11 +73,9 @@ if _implementation_version_str not in ('1', '2'): _implementation_version_str + "' (supported versions: 1, 2)" ) - _implementation_version = int(_implementation_version_str) - # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. @@ -82,6 +83,7 @@ _implementation_version = int(_implementation_version_str) def Type(): return _implementation_type + # See comment on 'Type' above. def Version(): return _implementation_version diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py new file mode 100644 index 00000000..b2b41284 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation_default_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test that the api_implementation defaults are what we expect.""" + +import os +import sys +# Clear environment implementation settings before the google3 imports. +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) + +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation + + +class ApiImplementationDefaultTest(basetest.TestCase): + + if sys.version_info.major <= 2: + + def testThatPythonIsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('python', api_implementation.Type()) + + else: + + def testThatCppApiV2IsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 34b35f8a..5797e81b 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -108,15 +108,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self._type_checker.CheckValue(value) - self._values.append(value) + self._values.append(self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" - self._type_checker.CheckValue(value) - self._values.insert(key, value) + self._values.insert(key, self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() @@ -127,8 +125,7 @@ class RepeatedScalarFieldContainer(BaseContainer): new_values = [] for elem in elem_seq: - self._type_checker.CheckValue(elem) - new_values.append(elem) + new_values.append(self._type_checker.CheckValue(elem)) self._values.extend(new_values) self._message_listener.Modified() @@ -146,9 +143,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def __setitem__(self, key, value): """Sets the item on the specified position.""" - self._type_checker.CheckValue(value) - self._values[key] = value - self._message_listener.Modified() + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -158,8 +159,7 @@ class RepeatedScalarFieldContainer(BaseContainer): """Sets the subset of items from between the specified indices.""" new_values = [] for value in values: - self._type_checker.CheckValue(value) - new_values.append(value) + new_values.append(self._type_checker.CheckValue(value)) self._values[start:stop] = new_values self._message_listener.Modified() diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py index 23ab9ba4..8eb38ca4 100755 --- a/python/google/protobuf/internal/cpp_message.py +++ b/python/google/protobuf/internal/cpp_message.py @@ -610,7 +610,7 @@ def _AddMessageMethods(message_descriptor, cls): return self._cmsg.FindInitializationErrors() def __str__(self): - return self._cmsg.DebugString() + return str(self._cmsg) def __eq__(self, other): if self is other: diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index cb6f5729..651ee0d4 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. @@ -81,6 +85,8 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -98,7 +104,7 @@ _NAN = _POS_INF * 0 _DecodeError = message.DecodeError -def _VarintDecoder(mask): +def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -109,15 +115,18 @@ def _VarintDecoder(mask): """ local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -125,15 +134,17 @@ def _VarintDecoder(mask): return DecodeVarint -def _SignedVarintDecoder(mask): +def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -142,19 +153,23 @@ def _SignedVarintDecoder(mask): result |= ~mask else: result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) def ReadTag(buffer, pos): @@ -168,8 +183,10 @@ def ReadTag(buffer, pos): use that, but not in Python. """ + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes start = pos - while ord(buffer[pos]) & 0x80: + while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -284,6 +301,7 @@ def _FloatDecoder(): """ local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign @@ -294,13 +312,17 @@ def _FloatDecoder(): # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # To avoid that, we parse it specially. - if ((float_bytes[3] in '\x7F\xFF') - and (float_bytes[2] >= '\x80')): + if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 +##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF') + and (float_bytes[2:3] >= b('\x80'))): ##PY25 +##!PY25 and (float_bytes[2:3] >= b'\x80')): # If at least one significand bit is set... - if float_bytes[0:3] != '\x00\x00\x80': + if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 +##!PY25 if float_bytes[0:3] != b'\x00\x00\x80': return (_NAN, new_pos) # If sign bit is set... - if float_bytes[3] == '\xFF': + if float_bytes[3:4] == b('\xFF'): ##PY25 +##!PY25 if float_bytes[3:4] == b'\xFF': return (_NEG_INF, new_pos) return (_POS_INF, new_pos) @@ -319,6 +341,7 @@ def _DoubleDecoder(): """ local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 def InnerDecode(buffer, pos): # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign @@ -329,9 +352,12 @@ def _DoubleDecoder(): # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it # as inf or -inf. To avoid that, we treat it specially. - if ((double_bytes[7] in '\x7F\xFF') - and (double_bytes[6] >= '\xF0') - and (double_bytes[0:7] != '\x00\x00\x00\x00\x00\x00\xF0')): +##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') +##!PY25 and (double_bytes[6:7] >= b'\xF0') +##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25 + and (double_bytes[6:7] >= b('\xF0')) ##PY25 + and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25 return (_NAN, new_pos) # Note that we expect someone up-stack to catch struct.error and convert @@ -342,10 +368,86 @@ def _DoubleDecoder(): return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + return pos + return DecodeField + + # -------------------------------------------------------------------- -Int32Decoder = EnumDecoder = _SimpleDecoder( +Int32Decoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -380,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): local_DecodeVarint = _DecodeVarint local_unicode = unicode + def _ConvertToUnicode(byte_str): + try: + return local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError, e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -394,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -407,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos return DecodeField @@ -631,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number): def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - - while ord(buffer[pos]) & 0x80: + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -699,7 +811,6 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK - local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -712,7 +823,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = local_ord(tag_bytes[0]) & wiretype_mask + wire_type = ord(tag_bytes[0:1]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField diff --git a/python/google/protobuf/internal/descriptor_cpp2_test.py b/python/google/protobuf/internal/descriptor_cpp2_test.py new file mode 100644 index 00000000..3a3ff298 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_cpp2_test.py @@ -0,0 +1,58 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.pyext behavior.""" + +__author__ = 'anuraag@google.com (Anuraag Agrawal)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmCppApi2Test(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index d0ca7892..856f4723 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -34,13 +34,13 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -import unittest +from google.apputils import basetest from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database -class DescriptorDatabaseTest(unittest.TestCase): +class DescriptorDatabaseTest(basetest.TestCase): def testAdd(self): db = descriptor_database.DescriptorDatabase() @@ -49,15 +49,15 @@ class DescriptorDatabaseTest(unittest.TestCase): db.Add(file_desc_proto) self.assertEquals(file_desc_proto, db.FindFileByName( - 'net/proto2/python/internal/factory_test2.proto')) + 'google/protobuf/internal/factory_test2.proto')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message')) + 'google.protobuf.python.internal.Factory2Message')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Message')) + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Enum')) + 'google.protobuf.python.internal.Factory2Enum')) self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum')) + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index a615d787..7c1ce2e1 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -34,8 +34,15 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' +import os import unittest + +from google.apputils import basetest +from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation +from google.protobuf.internal import descriptor_pool_test1_pb2 +from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor @@ -43,7 +50,7 @@ from google.protobuf import descriptor_database from google.protobuf import descriptor_pool -class DescriptorPoolTest(unittest.TestCase): +class DescriptorPoolTest(basetest.TestCase): def setUp(self): self.pool = descriptor_pool.DescriptorPool() @@ -55,57 +62,51 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.Add(self.factory_test2_fd) def testFindFileByName(self): - name1 = 'net/proto2/python/internal/factory_test1.proto' + name1 = 'google/protobuf/internal/factory_test1.proto' file_desc1 = self.pool.FindFileByName(name1) self.assertIsInstance(file_desc1, descriptor.FileDescriptor) self.assertEquals(name1, file_desc1.name) - self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) - name2 = 'net/proto2/python/internal/factory_test2.proto' + name2 = 'google/protobuf/internal/factory_test2.proto' file_desc2 = self.pool.FindFileByName(name2) self.assertIsInstance(file_desc2, descriptor.FileDescriptor) self.assertEquals(name2, file_desc2.name) - self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindFileByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindFileContainingSymbol(self): file_desc1 = self.pool.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory1Message') + 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(file_desc1, descriptor.FileDescriptor) - self.assertEquals('net/proto2/python/internal/factory_test1.proto', + self.assertEquals('google/protobuf/internal/factory_test1.proto', file_desc1.name) - self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) file_desc2 = self.pool.FindFileContainingSymbol( - 'net.proto2.python.internal.Factory2Message') + 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(file_desc2, descriptor.FileDescriptor) - self.assertEquals('net/proto2/python/internal/factory_test2.proto', + self.assertEquals('google/protobuf/internal/factory_test2.proto', file_desc2.name) - self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) def testFindFileContainingSymbolFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindMessageTypeByName(self): msg1 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory1Message') + 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(msg1, descriptor.Descriptor) self.assertEquals('Factory1Message', msg1.name) - self.assertEquals('net.proto2.python.internal.Factory1Message', + self.assertEquals('google.protobuf.python.internal.Factory1Message', msg1.full_name) self.assertEquals(None, msg1.containing_type) @@ -123,10 +124,10 @@ class DescriptorPoolTest(unittest.TestCase): 'nested_factory_1_enum'].enum_type) msg2 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message') + 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(msg2, descriptor.Descriptor) self.assertEquals('Factory2Message', msg2.name) - self.assertEquals('net.proto2.python.internal.Factory2Message', + self.assertEquals('google.protobuf.python.internal.Factory2Message', msg2.full_name) self.assertIsNone(msg2.containing_type) @@ -143,45 +144,57 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEquals(nested_enum2, msg2.fields_by_name[ 'nested_factory_2_enum'].enum_type) - self.assertTrue(msg2.fields_by_name['int_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) self.assertEquals( 1776, msg2.fields_by_name['int_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['double_with_default'].has_default) + self.assertTrue( + msg2.fields_by_name['double_with_default'].has_default_value) self.assertEquals( 9.99, msg2.fields_by_name['double_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['string_with_default'].has_default) + self.assertTrue( + msg2.fields_by_name['string_with_default'].has_default_value) self.assertEquals( 'hello world', msg2.fields_by_name['string_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) - self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default) + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) self.assertEquals( 1, msg2.fields_by_name['enum_with_default'].default_value) msg3 = self.pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Message') + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') self.assertEquals(nested_msg2, msg3) + self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) + self.assertEquals( + b'a\xfb\x00c', + msg2.fields_by_name['bytes_with_default'].default_value) + + self.assertEqual(1, len(msg2.oneofs)) + self.assertEqual(1, len(msg2.oneofs_by_name)) + self.assertEqual(2, len(msg2.oneofs[0].fields)) + for name in ['oneof_int', 'oneof_string']: + self.assertEqual(msg2.oneofs[0], + msg2.fields_by_name[name].containing_oneof) + self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + def testFindMessageTypeByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindMessageTypeByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testFindEnumTypeByName(self): enum1 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory1Enum') + 'google.protobuf.python.internal.Factory1Enum') self.assertIsInstance(enum1, descriptor.EnumDescriptor) self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) nested_enum1 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory1Message.NestedFactory1Enum') + 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) self.assertEquals( 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) @@ -189,13 +202,13 @@ class DescriptorPoolTest(unittest.TestCase): 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) enum2 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory2Enum') + 'google.protobuf.python.internal.Factory2Enum') self.assertIsInstance(enum2, descriptor.EnumDescriptor) self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) nested_enum2 = self.pool.FindEnumTypeByName( - 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum') + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) self.assertEquals( 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) @@ -203,11 +216,8 @@ class DescriptorPoolTest(unittest.TestCase): 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) def testFindEnumTypeByNameFailure(self): - try: + with self.assertRaises(KeyError): self.pool.FindEnumTypeByName('Does not exist') - self.fail('Expected KeyError') - except KeyError: - pass def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() @@ -216,5 +226,339 @@ class DescriptorPoolTest(unittest.TestCase): db.Add(self.factory_test2_fd) self.testFindMessageTypeByName() + def testComplexNesting(self): + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + + +class ProtoFile(object): + + def __init__(self, name, package, messages, dependencies=None): + self.name = name + self.package = package + self.messages = messages + self.dependencies = dependencies or [] + + def CheckFile(self, test, pool): + file_desc = pool.FindFileByName(self.name) + test.assertEquals(self.name, file_desc.name) + test.assertEquals(self.package, file_desc.package) + dependencies_names = [f.name for f in file_desc.dependencies] + test.assertEqual(self.dependencies, dependencies_names) + for name, msg_type in self.messages.items(): + msg_type.CheckType(test, None, name, file_desc) + + +class EnumType(object): + + def __init__(self, values): + self.values = values + + def CheckType(self, test, msg_desc, name, file_desc): + enum_desc = msg_desc.enum_types_by_name[name] + test.assertEqual(name, enum_desc.name) + expected_enum_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_enum_full_name, enum_desc.full_name) + test.assertEqual(msg_desc, enum_desc.containing_type) + test.assertEqual(file_desc, enum_desc.file) + for index, (value, number) in enumerate(self.values): + value_desc = enum_desc.values_by_name[value] + test.assertEqual(value, value_desc.name) + test.assertEqual(index, value_desc.index) + test.assertEqual(number, value_desc.number) + test.assertEqual(enum_desc, value_desc.type) + test.assertIn(value, msg_desc.enum_values_by_name) + + +class MessageType(object): + + def __init__(self, type_dict, field_list, is_extendable=False, + extensions=None): + self.type_dict = type_dict + self.field_list = field_list + self.is_extendable = is_extendable + self.extensions = extensions or [] + + def CheckType(self, test, containing_type_desc, name, file_desc): + if containing_type_desc is None: + desc = file_desc.message_types_by_name[name] + expected_full_name = '.'.join([file_desc.package, name]) + else: + desc = containing_type_desc.nested_types_by_name[name] + expected_full_name = '.'.join([containing_type_desc.full_name, name]) + + test.assertEqual(name, desc.name) + test.assertEqual(expected_full_name, desc.full_name) + test.assertEqual(containing_type_desc, desc.containing_type) + test.assertEqual(desc.file, file_desc) + test.assertEqual(self.is_extendable, desc.is_extendable) + for name, subtype in self.type_dict.items(): + subtype.CheckType(test, desc, name, file_desc) + + for index, (name, field) in enumerate(self.field_list): + field.CheckField(test, desc, name, index) + + for index, (name, field) in enumerate(self.extensions): + field.CheckField(test, desc, name, index) + + +class EnumField(object): + + def __init__(self, number, type_name, default_value): + self.number = number + self.type_name = type_name + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + enum_desc = msg_desc.enum_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_ENUM, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(enum_desc.values_by_name[self.default_value].index, + field_desc.default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(enum_desc, field_desc.enum_type) + + +class MessageField(object): + + def __init__(self, number, type_name): + self.number = number + self.type_name = type_name + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + field_type_desc = msg_desc.nested_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(field_type_desc, field_desc.message_type) + + +class StringField(object): + + def __init__(self, number, default_value): + self.number = number + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_STRING, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_STRING, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(self.default_value, field_desc.default_value) + + +class ExtensionField(object): + + def __init__(self, number, extended_type): + self.number = number + self.extended_type = extended_type + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.extensions_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(index, field_desc.index) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertTrue(field_desc.is_extension) + test.assertEqual(msg_desc, field_desc.extension_scope) + test.assertEqual(msg_desc, field_desc.message_type) + test.assertEqual(self.extended_type, field_desc.containing_type.name) + + +class AddDescriptorTest(basetest.TestCase): + + def _TestMessage(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes').full_name) + + # AddDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') + + pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + # Files are implicitly also indexed when messages are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + + def testMessage(self): + self._TestMessage('') + self._TestMessage('.') + + def _TestEnum(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum').full_name) + + # AddEnumDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') + + pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + # Files are implicitly also indexed when enums are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + + def testEnum(self): + self._TestEnum('') + self._TestEnum('.') + + def testFile(self): + pool = descriptor_pool.DescriptorPool() + pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + # AddFileDescriptor is not recursive; messages and enums within files must + # be explicitly registered. + with self.assertRaises(KeyError): + pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes') + + +TEST1_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test1.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest1': MessageType({ + 'NestedEnum': EnumType([('ALPHA', 1), ('BETA', 2)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('EPSILON', 5), ('ZETA', 6)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('ETA', 7), ('THETA', 8)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ETA')), + ('nested_field', StringField(2, 'theta')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ZETA')), + ('nested_field', StringField(2, 'beta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'BETA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], is_extendable=True), + + 'DescriptorPoolTest2': MessageType({ + 'NestedEnum': EnumType([('GAMMA', 3), ('DELTA', 4)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('IOTA', 9), ('KAPPA', 10)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('LAMBDA', 11), ('MU', 12)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'MU')), + ('nested_field', StringField(2, 'lambda')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'IOTA')), + ('nested_field', StringField(2, 'delta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'GAMMA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ]), + }) + + +TEST2_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test2.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest3': MessageType({ + 'NestedEnum': EnumType([('NU', 13), ('XI', 14)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('OMICRON', 15), ('PI', 16)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('RHO', 17), ('SIGMA', 18)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'RHO')), + ('nested_field', StringField(2, 'sigma')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'PI')), + ('nested_field', StringField(2, 'nu')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'XI')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], extensions=[ + ('descriptor_pool_test', + ExtensionField(1001, 'DescriptorPoolTest1')), + ]), + }, + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto new file mode 100644 index 00000000..c11dcc0e --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test1.proto @@ -0,0 +1,94 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + + +message DescriptorPoolTest1 { + extensions 1000 to max; + + enum NestedEnum { + ALPHA = 1; + BETA = 2; + } + + optional NestedEnum nested_enum = 1 [default = BETA]; + + message NestedMessage { + enum NestedEnum { + EPSILON = 5; + ZETA = 6; + } + optional NestedEnum nested_enum = 1 [default = ZETA]; + optional string nested_field = 2 [default = "beta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + ETA = 7; + THETA = 8; + } + optional NestedEnum nested_enum = 1 [default = ETA]; + optional string nested_field = 2 [default = "theta"]; + } + } + + optional NestedMessage nested_message = 2; +} + +message DescriptorPoolTest2 { + enum NestedEnum { + GAMMA = 3; + DELTA = 4; + } + + optional NestedEnum nested_enum = 1 [default = GAMMA]; + + message NestedMessage { + enum NestedEnum { + IOTA = 9; + KAPPA = 10; + } + optional NestedEnum nested_enum = 1 [default = IOTA]; + optional string nested_field = 2 [default = "delta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + LAMBDA = 11; + MU = 12; + } + optional NestedEnum nested_enum = 1 [default = MU]; + optional string nested_field = 2 [default = "lambda"]; + } + } + + optional NestedMessage nested_message = 2; +} diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto new file mode 100644 index 00000000..d97d39b4 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -0,0 +1,70 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +import "google/protobuf/internal/descriptor_pool_test1.proto"; + + +message DescriptorPoolTest3 { + + extend DescriptorPoolTest1 { + optional DescriptorPoolTest3 descriptor_pool_test = 1001; + } + + enum NestedEnum { + NU = 13; + XI = 14; + } + + optional NestedEnum nested_enum = 1 [default = XI]; + + message NestedMessage { + enum NestedEnum { + OMICRON = 15; + PI = 16; + } + optional NestedEnum nested_enum = 1 [default = PI]; + optional string nested_field = 2 [default = "nu"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + RHO = 17; + SIGMA = 18; + } + optional NestedEnum nested_enum = 1 [default = RHO]; + optional string nested_field = 2 [default = "sigma"]; + } + } + + optional NestedMessage nested_message = 2; +} + diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py new file mode 100644 index 00000000..b3a15710 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for descriptor.py for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index c74f882e..d20d9457 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -34,7 +34,7 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 @@ -48,7 +48,7 @@ name: 'TestEmptyMessage' """ -class DescriptorTest(unittest.TestCase): +class DescriptorTest(basetest.TestCase): def setUp(self): self.my_file = descriptor.FileDescriptor( @@ -244,7 +244,7 @@ class DescriptorTest(unittest.TestCase): unittest_custom_options_pb2.double_opt]) self.assertEqual("Hello, \"World\"", message_options.Extensions[ unittest_custom_options_pb2.string_opt]) - self.assertEqual("Hello\0World", message_options.Extensions[ + self.assertEqual(b"Hello\0World", message_options.Extensions[ unittest_custom_options_pb2.bytes_opt]) dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum self.assertEqual( @@ -395,7 +395,7 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.package, 'protobuf_unittest') -class DescriptorCopyToProtoTest(unittest.TestCase): +class DescriptorCopyToProtoTest(basetest.TestCase): """Tests for CopyTo functions of Descriptor.""" def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): @@ -530,47 +530,49 @@ class DescriptorCopyToProtoTest(unittest.TestCase): descriptor_pb2.DescriptorProto, TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) - def testCopyToProto_FileDescriptor(self): - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" - name: 'google/protobuf/unittest_import.proto' - package: 'protobuf_unittest_import' - dependency: 'google/protobuf/unittest_import_public.proto' - message_type: < - name: 'ImportMessage' - field: < - name: 'd' - number: 1 - label: 1 # Optional - type: 5 # TYPE_INT32 - > - > - """ + - """enum_type: < - name: 'ImportEnum' - value: < - name: 'IMPORT_FOO' - number: 7 - > - value: < - name: 'IMPORT_BAR' - number: 8 - > - value: < - name: 'IMPORT_BAZ' - number: 9 - > - > - options: < - java_package: 'com.google.protobuf.test' - optimize_for: 1 # SPEED - > - public_dependency: 0 - """) - - self._InternalTestCopyToProto( - unittest_import_pb2.DESCRIPTOR, - descriptor_pb2.FileDescriptorProto, - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + # Disable this test so we can make changes to the proto file. + # TODO(xiaofeng): Enable this test after cl/55530659 is submitted. + # + # def testCopyToProto_FileDescriptor(self): + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + # name: 'google/protobuf/unittest_import.proto' + # package: 'protobuf_unittest_import' + # dependency: 'google/protobuf/unittest_import_public.proto' + # message_type: < + # name: 'ImportMessage' + # field: < + # name: 'd' + # number: 1 + # label: 1 # Optional + # type: 5 # TYPE_INT32 + # > + # > + # """ + + # """enum_type: < + # name: 'ImportEnum' + # value: < + # name: 'IMPORT_FOO' + # number: 7 + # > + # value: < + # name: 'IMPORT_BAR' + # number: 8 + # > + # value: < + # name: 'IMPORT_BAZ' + # number: 9 + # > + # > + # options: < + # java_package: 'com.google.protobuf.test' + # optimize_for: 1 # SPEED + # > + # public_dependency: 0 + # """) + # self._InternalTestCopyToProto( + # unittest_import_pb2.DESCRIPTOR, + # descriptor_pb2.FileDescriptorProto, + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) def testCopyToProto_ServiceDescriptor(self): TEST_SERVICE_ASCII = """ @@ -586,28 +588,82 @@ class DescriptorCopyToProtoTest(unittest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( unittest_pb2.TestService.DESCRIPTOR, descriptor_pb2.ServiceDescriptorProto, TEST_SERVICE_ASCII) -class MakeDescriptorTest(unittest.TestCase): +class MakeDescriptorTest(basetest.TestCase): + + def testMakeDescriptorWithNestedFields(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo2' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + nested_type = message_type.nested_type.add() + nested_type.name = 'Sub' + enum_type = nested_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + field = message_type.field.add() + field.number = 2 + field.name = 'nested_message_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_MESSAGE + field.type_name = 'Sub' + enum_field = nested_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo2.Sub.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + self.assertEqual(result.fields[1].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_MESSAGE) + self.assertEqual(result.fields[1].message_type.containing_type, + result) + self.assertEqual(result.nested_types[0].fields[0].full_name, + 'Foo2.Sub.bar_field') + self.assertEqual(result.nested_types[0].fields[0].enum_type, + result.nested_types[0].enum_types[0]) + def testMakeDescriptorWithUnsignedIntField(self): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() file_descriptor_proto.name = 'Foo' message_type = file_descriptor_proto.message_type.add() message_type.name = file_descriptor_proto.name + enum_type = message_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 field = message_type.field.add() field.number = 1 field.name = 'uint64_field' field.label = descriptor.FieldDescriptor.LABEL_REQUIRED field.type = descriptor.FieldDescriptor.TYPE_UINT64 + enum_field = message_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo.FOO' + result = descriptor.MakeDescriptor(message_type) self.assertEqual(result.fields[0].cpp_type, descriptor.FieldDescriptor.CPPTYPE_UINT64) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 777975e8..0a7c0417 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type @@ -67,6 +71,8 @@ sizer rather than when calling them. In particular: __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import wire_format @@ -340,7 +346,8 @@ def MessageSetItemSizer(field_number): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeVarint(write, value): bits = value & 0x7f value >>= 7 @@ -357,7 +364,8 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeSignedVarint(write, value): if value < 0: value += (1 << 64) @@ -382,7 +390,8 @@ def _VarintBytes(value): pieces = [] _EncodeVarint(pieces.append, value) - return "".join(pieces) + return "".encode("latin1").join(pieces) ##PY25 +##!PY25 return b"".join(pieces) def TagBytes(field_number, wire_type): @@ -520,26 +529,33 @@ def _FloatingPointEncoder(wire_type, format): format: The format string to pass to struct.pack(). """ + b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25 value_size = struct.calcsize(format) if value_size == 4: def EncodeNonFiniteOrRaise(write, value): # Remember that the serialized form uses little-endian byte order. if value == _POS_INF: - write('\x00\x00\x80\x7F') + write(b('\x00\x00\x80\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x80\x7F') elif value == _NEG_INF: - write('\x00\x00\x80\xFF') + write(b('\x00\x00\x80\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x80\xFF') elif value != value: # NaN - write('\x00\x00\xC0\x7F') + write(b('\x00\x00\xC0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\xC0\x7F') else: raise elif value_size == 8: def EncodeNonFiniteOrRaise(write, value): if value == _POS_INF: - write('\x00\x00\x00\x00\x00\x00\xF0\x7F') + write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') elif value == _NEG_INF: - write('\x00\x00\x00\x00\x00\x00\xF0\xFF') + write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') elif value != value: # NaN - write('\x00\x00\x00\x00\x00\x00\xF8\x7F') + write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') else: raise else: @@ -615,8 +631,10 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') def BoolEncoder(field_number, is_repeated, is_packed): """Returns an encoder for a boolean field.""" - false_byte = chr(0) - true_byte = chr(1) +##!PY25 false_byte = b'\x00' +##!PY25 true_byte = b'\x01' + false_byte = '\x00'.encode('latin1') ##PY25 + true_byte = '\x01'.encode('latin1') ##PY25 if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint @@ -752,7 +770,8 @@ def MessageSetItemEncoder(field_number): } } """ - start_bytes = "".join([ + start_bytes = "".encode("latin1").join([ ##PY25 +##!PY25 start_bytes = b"".join([ TagBytes(1, wire_format.WIRETYPE_START_GROUP), TagBytes(2, wire_format.WIRETYPE_VARINT), _VarintBytes(field_number), diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto index 9f55e037..03dcb2ca 100644 --- a/python/google/protobuf/internal/factory_test1.proto +++ b/python/google/protobuf/internal/factory_test1.proto @@ -52,4 +52,6 @@ message Factory1Message { optional NestedFactory1Message nested_factory_1_message = 3; optional int32 scalar_value = 4; repeated string list_value = 5; + + extensions 1000 to max; } diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto index d3ce4d7f..a8c68123 100644 --- a/python/google/protobuf/internal/factory_test2.proto +++ b/python/google/protobuf/internal/factory_test2.proto @@ -70,8 +70,23 @@ message Factory2Message { optional string string_with_default = 18 [default = "hello world"]; optional bool bool_with_default = 19 [default = false]; optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1]; + optional bytes bytes_with_default = 21 [default = "a\373\000c"]; + + + extend Factory1Message { + optional string one_more_field = 1001; + } + + oneof oneof_field { + int32 oneof_int = 22; + string oneof_string = 23; + } } message LoopMessage { optional Factory2Message loop = 1; } + +extend Factory1Message { + optional string another_field = 1002; +} diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 8343aba1..5818060c 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -41,20 +41,21 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 -from google.protobuf import unittest_pb2 from google.protobuf import unittest_no_generic_services_pb2 +from google.protobuf import unittest_pb2 from google.protobuf import service +from google.protobuf import symbol_database MAX_EXTENSION = 536870912 -class GeneratorTest(unittest.TestCase): +class GeneratorTest(basetest.TestCase): def testNestedMessageDescriptor(self): field_name = 'optional_nested_message' @@ -217,6 +218,10 @@ class GeneratorTest(unittest.TestCase): 'google/protobuf/unittest.proto') self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest') self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None) + self.assertEqual(unittest_pb2.DESCRIPTOR.dependencies, + [unittest_import_pb2.DESCRIPTOR]) + self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies, + [unittest_import_public_pb2.DESCRIPTOR]) def testNoGenericServices(self): self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) @@ -241,6 +246,18 @@ class GeneratorTest(unittest.TestCase): unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in file_type.message_types_by_name) + def testEnumTypesByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2._FOREIGNENUM, + file_type.enum_types_by_name[unittest_pb2._FOREIGNENUM.name]) + + def testExtensionsByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2.my_extension_string, + file_type.extensions_by_name[unittest_pb2.my_extension_string.name]) + def testPublicImports(self): # Test public imports as embedded message. all_type_proto = unittest_pb2.TestAllTypes() @@ -265,5 +282,62 @@ class GeneratorTest(unittest.TestCase): self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service], "qux") + def testOneof(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR + self.assertEqual(1, len(desc.oneofs)) + self.assertEqual('oneof_field', desc.oneofs[0].name) + self.assertEqual(0, desc.oneofs[0].index) + self.assertIs(desc, desc.oneofs[0].containing_type) + self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field']) + nested_names = set(['oneof_uint32', 'oneof_nested_message', + 'oneof_string', 'oneof_bytes']) + self.assertSameElements( + nested_names, + [field.name for field in desc.oneofs[0].fields]) + for field_name, field_desc in desc.fields_by_name.iteritems(): + if field_name in nested_names: + self.assertIs(desc.oneofs[0], field_desc.containing_oneof) + else: + self.assertIsNone(field_desc.containing_oneof) + + +class SymbolDatabaseRegistrationTest(basetest.TestCase): + """Checks that messages, enums and files are correctly registered.""" + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + with self.assertRaises(KeyError): + symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage') + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + symbol_database.Default().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_cpp2_test.py b/python/google/protobuf/internal/message_factory_cpp2_test.py new file mode 100644 index 00000000..fb52e1b1 --- /dev/null +++ b/python/google/protobuf/internal/message_factory_cpp2_test.py @@ -0,0 +1,56 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_factory_test import * + + +class ConfirmCppApi2Test(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py new file mode 100644 index 00000000..6a2053ff --- /dev/null +++ b/python/google/protobuf/internal/message_factory_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for ..public.message_factory for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_factory_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 0bc9be99..c53d77bf 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -34,7 +34,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -import unittest +from google.apputils import basetest from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 @@ -43,7 +43,7 @@ from google.protobuf import descriptor_pool from google.protobuf import message_factory -class MessageFactoryTest(unittest.TestCase): +class MessageFactoryTest(basetest.TestCase): def setUp(self): self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( @@ -61,8 +61,8 @@ class MessageFactoryTest(unittest.TestCase): msg.factory_1_message.nested_factory_1_message.value = ( 'nested message value') msg.factory_1_message.scalar_value = 22 - msg.factory_1_message.list_value.extend(['one', 'two', 'three']) - msg.factory_1_message.list_value.append('four') + msg.factory_1_message.list_value.extend([u'one', u'two', u'three']) + msg.factory_1_message.list_value.append(u'four') msg.factory_1_enum = 1 msg.nested_factory_1_enum = 0 msg.nested_factory_1_message.value = 'nested message value' @@ -70,8 +70,8 @@ class MessageFactoryTest(unittest.TestCase): msg.circular_message.circular_message.mandatory = 2 msg.circular_message.scalar_value = 'one deep' msg.scalar_value = 'zero deep' - msg.list_value.extend(['four', 'three', 'two']) - msg.list_value.append('one') + msg.list_value.extend([u'four', u'three', u'two']) + msg.list_value.append(u'one') msg.grouped.add() msg.grouped[0].part_1 = 'hello' msg.grouped[0].part_2 = 'world' @@ -92,22 +92,40 @@ class MessageFactoryTest(unittest.TestCase): db.Add(self.factory_test2_fd) factory = message_factory.MessageFactory() cls = factory.GetPrototype(pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message')) + 'google.protobuf.python.internal.Factory2Message')) self.assertIsNot(cls, factory_test2_pb2.Factory2Message) self._ExerciseDynamicClass(cls) cls2 = factory.GetPrototype(pool.FindMessageTypeByName( - 'net.proto2.python.internal.Factory2Message')) + 'google.protobuf.python.internal.Factory2Message')) self.assertIs(cls, cls2) def testGetMessages(self): - messages = message_factory.GetMessages([self.factory_test2_fd, - self.factory_test1_fd]) - self.assertContainsSubset( - ['net.proto2.python.internal.Factory2Message', - 'net.proto2.python.internal.Factory1Message'], - messages.keys()) - self._ExerciseDynamicClass( - messages['net.proto2.python.internal.Factory2Message']) + # performed twice because multiple calls with the same input must be allowed + for _ in range(2): + messages = message_factory.GetMessages([self.factory_test2_fd, + self.factory_test1_fd]) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message', + 'google.protobuf.python.internal.Factory1Message'], + messages.keys()) + self._ExerciseDynamicClass( + messages['google.protobuf.python.internal.Factory2Message']) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message.one_more_field', + 'google.protobuf.python.internal.another_field'], + (messages['google.protobuf.python.internal.Factory1Message'] + ._extensions_by_name.keys())) + factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] + msg1 = messages['google.protobuf.python.internal.Factory1Message']() + ext1 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.Factory2Message.one_more_field'] + ext2 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.another_field'] + msg1.Extensions[ext1] = 'test1' + msg1.Extensions[ext2] = 'test2' + self.assertEquals('test1', msg1.Extensions[ext1]) + self.assertEquals('test2', msg1.Extensions[ext2]) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/message_cpp_test.py b/python/google/protobuf/internal/message_python_test.py index 0d84b320..baf15044 100644 --- a/python/google/protobuf/internal/message_cpp_test.py +++ b/python/google/protobuf/internal/message_python_test.py @@ -30,16 +30,25 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Tests for google.protobuf.internal.message_cpp.""" - -__author__ = 'shahms@google.com (Shahms King)' +"""Tests for ..public.message for the pure Python implementation.""" import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' -import unittest +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import from google.protobuf.internal.message_test import * +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 53e9d507..f4c4ae00 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -47,9 +47,9 @@ import copy import math import operator import pickle +import sys -import unittest -from google.protobuf import unittest_import_pb2 +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util @@ -68,10 +68,23 @@ def IsPosInf(val): def IsNegInf(val): return isinf(val) and (val < 0) -class MessageTest(unittest.TestCase): + +class MessageTest(basetest.TestCase): + + def testBadUtf8String(self): + if api_implementation.Type() != 'python': + self.skipTest("Skipping testBadUtf8String, currently only the python " + "api implementation raises UnicodeDecodeError when a " + "string field contains bad utf-8.") + bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') + with self.assertRaises(UnicodeDecodeError) as context: + unittest_pb2.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', + str(context.exception)) def testGoldenMessage(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) @@ -80,7 +93,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestAllExtensions() @@ -91,7 +104,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedTypes() @@ -102,7 +115,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedExtensions() @@ -113,7 +126,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testPickleSupport(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) pickled_message = pickle.dumps(golden_message) @@ -121,6 +134,7 @@ class MessageTest(unittest.TestCase): unpickled_message = pickle.loads(pickled_message) self.assertEquals(unpickled_message, golden_message) + def testPickleIncompleteProto(self): golden_message = unittest_pb2.TestRequired(a=1) pickled_message = pickle.dumps(golden_message) @@ -132,10 +146,10 @@ class MessageTest(unittest.TestCase): self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) def testPositiveInfinity(self): - golden_data = ('\x5D\x00\x00\x80\x7F' - '\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - '\xCD\x02\x00\x00\x80\x7F' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) @@ -145,10 +159,10 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self): - golden_data = ('\x5D\x00\x00\x80\xFF' - '\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - '\xCD\x02\x00\x00\x80\xFF' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) @@ -158,10 +172,10 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumber(self): - golden_data = ('\x5D\x00\x00\xC0\x7F' - '\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' - '\xCD\x02\x00\x00\xC0\x7F' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_data = (b'\x5D\x00\x00\xC0\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' + b'\xCD\x02\x00\x00\xC0\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.optional_float)) @@ -182,8 +196,8 @@ class MessageTest(unittest.TestCase): self.assertTrue(isnan(message.repeated_double[0])) def testPositiveInfinityPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) @@ -191,8 +205,8 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinityPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) @@ -200,8 +214,8 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumberPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) @@ -303,6 +317,26 @@ class MessageTest(unittest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) + def testFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_float = 2.0 + self.assertEqual(str(message), 'optional_float: 2.0\n') + + def testHighPrecisionFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_double = 0.12345678912345678 + if sys.version_info.major >= 3: + self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') + else: + self.assertEqual(str(message), 'optional_double: 0.123456789123\n') + + def testUnknownFieldPrinting(self): + populated = unittest_pb2.TestAllTypes() + test_util.SetAllNonLazyFields(populated) + empty = unittest_pb2.TestEmptyMessage() + empty.ParseFromString(populated.SerializeToString()) + self.assertEqual(str(empty), '') + def testSortingRepeatedScalarFieldsDefaultComparator(self): """Check some different types with the default comparator.""" message = unittest_pb2.TestAllTypes() @@ -332,13 +366,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_string[1], 'b') self.assertEqual(message.repeated_string[2], 'c') - message.repeated_bytes.append('a') - message.repeated_bytes.append('c') - message.repeated_bytes.append('b') + message.repeated_bytes.append(b'a') + message.repeated_bytes.append(b'c') + message.repeated_bytes.append(b'b') message.repeated_bytes.sort() - self.assertEqual(message.repeated_bytes[0], 'a') - self.assertEqual(message.repeated_bytes[1], 'b') - self.assertEqual(message.repeated_bytes[2], 'c') + self.assertEqual(message.repeated_bytes[0], b'a') + self.assertEqual(message.repeated_bytes[1], b'b') + self.assertEqual(message.repeated_bytes[2], b'c') def testSortingRepeatedScalarFieldsCustomComparator(self): """Check some different types with custom comparator.""" @@ -347,7 +381,7 @@ class MessageTest(unittest.TestCase): message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) - message.repeated_int32.sort(lambda x,y: cmp(abs(x), abs(y))) + message.repeated_int32.sort(key=abs) self.assertEqual(message.repeated_int32[0], -1) self.assertEqual(message.repeated_int32[1], -2) self.assertEqual(message.repeated_int32[2], -3) @@ -355,7 +389,7 @@ class MessageTest(unittest.TestCase): message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') - message.repeated_string.sort(lambda x,y: cmp(len(x), len(y))) + message.repeated_string.sort(key=len) self.assertEqual(message.repeated_string[0], 'c') self.assertEqual(message.repeated_string[1], 'bb') self.assertEqual(message.repeated_string[2], 'aaa') @@ -370,7 +404,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.add().bb = 6 message.repeated_nested_message.add().bb = 5 message.repeated_nested_message.add().bb = 4 - message.repeated_nested_message.sort(lambda x,y: cmp(x.bb, y.bb)) + message.repeated_nested_message.sort(key=operator.attrgetter('bb')) self.assertEqual(message.repeated_nested_message[0].bb, 1) self.assertEqual(message.repeated_nested_message[1].bb, 2) self.assertEqual(message.repeated_nested_message[2].bb, 3) @@ -396,6 +430,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.sort(key=get_bb, reverse=True) self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) + if sys.version_info.major >= 3: return # No cmp sorting in PY3. message.repeated_nested_message.sort(sort_function=cmp_bb) self.assertEqual([k.bb for k in message.repeated_nested_message], [1, 2, 3, 4, 5, 6]) @@ -407,7 +442,6 @@ class MessageTest(unittest.TestCase): """Check sorting a scalar field using list.sort() arguments.""" message = unittest_pb2.TestAllTypes() - abs_cmp = lambda a, b: cmp(abs(a), abs(b)) message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) @@ -415,12 +449,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) message.repeated_int32.sort(key=abs, reverse=True) self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - message.repeated_int32.sort(sort_function=abs_cmp) - self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) - message.repeated_int32.sort(cmp=abs_cmp, reverse=True) - self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + if sys.version_info.major < 3: # No cmp sorting in PY3. + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - len_cmp = lambda a, b: cmp(len(a), len(b)) message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') @@ -428,10 +463,47 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) message.repeated_string.sort(key=len, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - message.repeated_string.sort(sort_function=len_cmp) - self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) - message.repeated_string.sort(cmp=len_cmp, reverse=True) - self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + if sys.version_info.major < 3: # No cmp sorting in PY3. + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testRepeatedFieldsComparable(self): + m1 = unittest_pb2.TestAllTypes() + m2 = unittest_pb2.TestAllTypes() + m1.repeated_int32.append(0) + m1.repeated_int32.append(1) + m1.repeated_int32.append(2) + m2.repeated_int32.append(0) + m2.repeated_int32.append(1) + m2.repeated_int32.append(2) + m1.repeated_nested_message.add().bb = 1 + m1.repeated_nested_message.add().bb = 2 + m1.repeated_nested_message.add().bb = 3 + m2.repeated_nested_message.add().bb = 1 + m2.repeated_nested_message.add().bb = 2 + m2.repeated_nested_message.add().bb = 3 + + if sys.version_info.major >= 3: return # No cmp() in PY3. + + # These comparisons should not raise errors. + _ = m1 < m2 + _ = m1.repeated_nested_message < m2.repeated_nested_message + + # Make sure cmp always works. If it wasn't defined, these would be + # id() comparisons and would all fail. + self.assertEqual(cmp(m1, m2), 0) + self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) + self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) + self.assertEqual(cmp(m1.repeated_nested_message, + m2.repeated_nested_message), 0) + with self.assertRaises(TypeError): + # Can't compare repeated composite containers to lists. + cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) + + # TODO(anuraag): Implement extensiondict comparison in C++ and then add test def testParsingMerge(self): """Check the merge behavior when a required or optional field appears @@ -482,6 +554,87 @@ class MessageTest(unittest.TestCase): self.assertEqual(len(parsing_merge.Extensions[ unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def ensureNestedMessageExists(self, msg, attribute): + """Make sure that a nested message object exists. + + As soon as a nested message attribute is accessed, it will be present in the + _fields dict, without being marked as actually being set. + """ + getattr(msg, attribute) + self.assertFalse(msg.HasField(attribute)) + + def testOneofGetCaseNonexistingField(self): + m = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + + def testOneofSemantics(self): + m = unittest_pb2.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + + m.oneof_uint32 = 11 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + + m.oneof_string = u'foo' + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertTrue(m.HasField('oneof_string')) + + m.oneof_nested_message.bb = 11 + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_string')) + self.assertTrue(m.HasField('oneof_nested_message')) + + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_nested_message')) + self.assertTrue(m.HasField('oneof_bytes')) + + def testOneofCompositeFieldReadAccess(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + + self.ensureNestedMessageExists(m, 'oneof_nested_message') + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertEqual(11, m.oneof_uint32) + + def testOneofHasField(self): + m = unittest_pb2.TestAllTypes() + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' + self.assertTrue(m.HasField('oneof_field')) + m.ClearField('oneof_bytes') + self.assertFalse(m.HasField('oneof_field')) + + def testOneofClearField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_field') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearSetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_uint32') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearUnsetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + self.ensureNestedMessageExists(m, 'oneof_nested_message') + m.ClearField('oneof_nested_message') + self.assertEqual(11, m.oneof_uint32) + self.assertTrue(m.HasField('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + + def testSortEmptyRepeatedCompositeContainer(self): """Exercise a scenario that has led to segfaults in the past. @@ -489,6 +642,35 @@ class MessageTest(unittest.TestCase): m = unittest_pb2.TestAllTypes() m.repeated_nested_message.sort() + def testHasFieldOnRepeatedField(self): + """Using HasField on a repeated field should raise an exception. + """ + m = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError) as _: + m.HasField('repeated_int32') + + +class ValidTypeNamesTest(basetest.TestCase): + + def assertImportFromName(self, msg, base_name): + # Parse <type 'module.class_name'> to extra 'some.name' as a string. + tp_name = str(type(msg)).split("'")[1] + valid_names = ('Repeated%sContainer' % base_name, + 'Repeated%sFieldContainer' % base_name) + self.assertTrue(any(tp_name.endswith(v) for v in valid_names), + '%r does end with any of %r' % (tp_name, valid_names)) + + parts = tp_name.split('.') + class_name = parts[-1] + module_name = '.'.join(parts[:-1]) + __import__(module_name, fromlist=[class_name]) + + def testTypeNamesCanBeImported(self): + # If import doesn't work, pickling won't work either. + pb = unittest_pb2.TestAllTypes() + self.assertImportFromName(pb.repeated_int32, 'Scalar') + self.assertImportFromName(pb.repeated_nested_message, 'Composite') + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto new file mode 100644 index 00000000..c9ae58b6 --- /dev/null +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -0,0 +1,50 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +message TestEnumValues { + enum NestedEnum { + ZERO = 0; + ONE = 1; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} + +message TestMissingEnumValues { + enum NestedEnum { + TWO = 2; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 4bea57ac..9ee352d6 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Keep it Python2.5 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. +# # This code is meant to work on Python 2.4 and above only. # # TODO(robinson): Helpers for verbose, common checks like seeing if a @@ -50,11 +54,16 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -import copy_reg +import sys +if sys.version_info[0] < 3: + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + import copy_reg as copyreg +else: + from io import BytesIO + import copyreg import struct import weakref @@ -98,8 +107,8 @@ def InitMessage(descriptor, cls): _AddPropertiesForExtensions(descriptor, cls) _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(cls) - copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -176,7 +185,8 @@ def _AddSlots(message_descriptor, dictionary): '_is_present_in_parent', '_listener', '_listener_for_children', - '__weakref__'] + '__weakref__', + '_oneofs'] def _IsMessageSetExtension(field): @@ -272,7 +282,7 @@ def _DefaultValueConstructorForField(field): message._listener_for_children, field.message_type) return MakeRepeatedMessageDefault else: - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) def MakeRepeatedScalarDefault(message): return containers.RepeatedScalarFieldContainer( message._listener_for_children, type_checker) @@ -301,6 +311,10 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () @@ -440,7 +454,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): """ proto_field_name = field.name property_name = _PropertyName(proto_field_name) - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() @@ -450,14 +464,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name - def setter(self, new_value): - type_checker.CheckValue(new_value) - self._fields[field] = new_value + def field_setter(self, new_value): + # pylint: disable=protected-access + self._fields[field] = type_checker.CheckValue(new_value) # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() + if field.containing_oneof is not None: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -493,7 +514,10 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): if field_value is None: # Construct a new object to represent this field. field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener(self._listener_for_children) + field_value._SetListener( + _OneofListener(self, field) + if field.containing_oneof is not None + else self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -589,6 +613,9 @@ def _AddHasFieldMethod(message_descriptor, cls): for field in message_descriptor.fields: if field.label != _FieldDescriptor.LABEL_REPEATED: singular_fields[field.name] = field + # Fields inside oneofs are never repeated (enforced by the compiler). + for field in message_descriptor.oneofs: + singular_fields[field.name] = field def HasField(self, field_name): try: @@ -597,11 +624,18 @@ def _AddHasFieldMethod(message_descriptor, cls): raise ValueError( 'Protocol message has no singular "%s" field.' % field_name) - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(field) - return value is not None and value._is_present_in_parent + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False else: - return field in self._fields + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + cls.HasField = HasField @@ -611,7 +645,14 @@ def _AddClearFieldMethod(message_descriptor, cls): try: field = message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) if field in self._fields: # Note: If the field is a sub-message, its listener will still point @@ -619,6 +660,9 @@ def _AddClearFieldMethod(message_descriptor, cls): # will call _Modified() and invalidate our byte size. Big deal. del self._fields[field] + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + # Always call _Modified() -- even if nothing was changed, this is # a mutating method, and thus calling it should cause the field to become # present in the parent message. @@ -773,7 +817,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializePartialToString(self): - out = StringIO() + out = BytesIO() self._InternalSerialize(out.write) return out.getvalue() cls.SerializePartialToString = SerializePartialToString @@ -796,7 +840,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise message_mod.DecodeError('Unexpected end-group tag.') - except IndexError: + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. raise message_mod.DecodeError('Truncated message.') except struct.error, e: raise message_mod.DecodeError(e) @@ -857,7 +902,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): errors.extend(self.FindInitializationErrors()) return False - for field, value in self._fields.iteritems(): + for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: for element in value: @@ -953,6 +998,24 @@ def _AddMergeFromMethod(cls): cls.MergeFrom = MergeFrom +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -972,9 +1035,9 @@ def _AddMessageMethods(message_descriptor, cls): _AddMergeFromStringMethod(message_descriptor, cls) _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) - -def _AddPrivateHelperMethods(cls): +def _AddPrivateHelperMethods(message_descriptor, cls): """Adds implementation of private helper methods to cls.""" def Modified(self): @@ -992,8 +1055,20 @@ def _AddPrivateHelperMethods(cls): self._is_present_in_parent = True self._listener.Modified() + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + cls._Modified = Modified cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState class _Listener(object): @@ -1042,6 +1117,27 @@ class _Listener(object): pass +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass + + # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. # TODO(robinson): Support iteritems()-style iteration over all @@ -1133,9 +1229,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. type_checker = type_checkers.GetTypeChecker( - extension_handle.cpp_type, extension_handle.type) - type_checker.CheckValue(value) - self._extended_message._fields[extension_handle] = value + extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): diff --git a/python/google/protobuf/internal/reflection_cpp_generated_test.py b/python/google/protobuf/internal/reflection_cpp2_generated_test.py index 2a0a5124..d7fce5fa 100755 --- a/python/google/protobuf/internal/reflection_cpp_generated_test.py +++ b/python/google/protobuf/internal/reflection_cpp2_generated_test.py @@ -37,17 +37,19 @@ __author__ = 'jasonh@google.com (Jason Hsueh)' import os os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' -import unittest +from google.apputils import basetest from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_dynamic_pb2 from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal.reflection_test import * -class ReflectionCppTest(unittest.TestCase): +class ReflectionCppTest(basetest.TestCase): def testImplementationSetting(self): self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) def testExtensionOfGeneratedTypeInDynamicFile(self): """Tests that a file built dynamically can extend a generated C++ type. @@ -87,5 +89,6 @@ class ReflectionCppTest(unittest.TestCase): pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index ed286461..b3c414c7 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,11 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy import gc import operator import struct -import unittest +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -49,6 +50,7 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 @@ -102,7 +104,7 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) @@ -533,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -549,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -566,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -594,6 +607,30 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() @@ -613,29 +650,6 @@ class ReflectionTest(unittest.TestCase): pb.optional_nested_enum = 1 self.assertEqual(1, pb.optional_nested_enum) - # Invalid enum values. - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - bytes_size_before = pb.ByteSize() - - pb.optional_nested_enum = 4 - self.assertEqual(4, pb.optional_nested_enum) - - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - # Make sure that setting the same enum field doesn't just add unknown - # fields (but overwrites them). - self.assertEqual(bytes_size_before, pb.ByteSize()) - - # Is the invalid value preserved after serialization? - serialized = pb.SerializeToString() - pb2 = unittest_pb2.TestAllTypes() - pb2.ParseFromString(serialized) - self.assertEqual(0, pb2.optional_nested_enum) - self.assertEqual(pb, pb2) - def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) @@ -749,9 +763,9 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignEnum.items()) proto = unittest_pb2.TestAllTypes() - self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) - self.assertEqual([1, 2, 3], proto.NestedEnum.values()) - self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], proto.NestedEnum.items()) def testRepeatedScalars(self): @@ -1155,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -1451,6 +1473,19 @@ class ReflectionTest(unittest.TestCase): proto2 = unittest_pb2.TestAllExtensions() self.assertRaises(TypeError, proto1.CopyFrom, proto2) + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() # C++ implementation does not support lazy fields right now so leave it @@ -1496,11 +1531,23 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(6, foreign.c) nested.bb = 15 foreign.c = 16 - self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertFalse(proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) - self.assertTrue(not proto.HasField('optional_foreign_message')) + self.assertFalse(proto.HasField('optional_foreign_message')) self.assertEqual(0, proto.optional_foreign_message.c) + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1571,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1588,17 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - if api_implementation.Type() == 'python': - # Values of type 'str' are also accepted as long as they can be - # encoded in UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1621,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1632,7 +1712,8 @@ class ReflectionTest(unittest.TestCase): # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) @@ -1643,17 +1724,22 @@ class ReflectionTest(unittest.TestCase): # MergeFromString and thus has no way to throw the exception. # # The pure Python API always returns objects of type 'unicode' (UTF-8 - # encoded), or 'str' (in 7 bit ASCII). - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') unicode_decode_failed = False try: - message2.MergeFromString(bytes) - except UnicodeDecodeError as e: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: unicode_decode_failed = True string_field = message2.str - self.assertTrue(unicode_decode_failed or type(string_field) == str) + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1667,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1690,7 +1779,7 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() @@ -1706,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1792,7 +1881,7 @@ class FullProtosEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1825,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1837,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -2133,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -2149,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -2157,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -2249,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -2274,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -2327,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -2345,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -2355,7 +2474,9 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" @@ -2406,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -2478,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2511,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2520,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2572,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2590,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2640,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2667,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index e04f8252..ef0981d9 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -34,13 +34,13 @@ __author__ = 'petar@google.com (Petar Petrov)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service -class FooUnitTest(unittest.TestCase): +class FooUnitTest(basetest.TestCase): def testService(self): class MockRpcChannel(service.RpcChannel): @@ -133,4 +133,4 @@ class FooUnitTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py new file mode 100644 index 00000000..80bc8d6e --- /dev/null +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -0,0 +1,120 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.symbol_database.""" + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import symbol_database + + +class SymbolDatabaseTest(basetest.TestCase): + + def _Database(self): + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db + + def testGetPrototype(self): + instance = self._Database().GetPrototype( + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertTrue(instance is unittest_pb2.TestAllTypes) + + def testGetMessages(self): + messages = self._Database().GetMessages( + ['google/protobuf/unittest.proto']) + self.assertTrue( + unittest_pb2.TestAllTypes is + messages['protobuf_unittest.TestAllTypes']) + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + # Check registration of types in the pool. + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindMessageTypeByName(self): + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + def testFindFindContainingSymbol(self): + # Lookup based on either enum or message. + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes').name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index be8ae7be..350d1c6d 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -66,14 +66,8 @@ def SetAllNonLazyFields(message): message.optional_float = 111 message.optional_double = 112 message.optional_bool = True - # TODO(robinson): Firmly spec out and test how - # protos interact with unicode. One specific example: - # what happens if we change the literal below to - # u'115'? What *should* happen? Still some discussion - # to finish with Kenton about bytes vs. strings - # and forcing everything to be utf8. :-/ - message.optional_string = '115' - message.optional_bytes = '116' + message.optional_string = u'115' + message.optional_bytes = b'116' message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 @@ -85,8 +79,8 @@ def SetAllNonLazyFields(message): message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ - message.optional_string_piece = '124' - message.optional_cord = '125' + message.optional_string_piece = u'124' + message.optional_cord = u'125' # # Repeated fields. @@ -105,8 +99,8 @@ def SetAllNonLazyFields(message): message.repeated_float.append(211) message.repeated_double.append(212) message.repeated_bool.append(True) - message.repeated_string.append('215') - message.repeated_bytes.append('216') + message.repeated_string.append(u'215') + message.repeated_bytes.append(b'216') message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 @@ -118,8 +112,8 @@ def SetAllNonLazyFields(message): message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) - message.repeated_string_piece.append('224') - message.repeated_cord.append('225') + message.repeated_string_piece.append(u'224') + message.repeated_cord.append(u'225') # Add a second one of each field. message.repeated_int32.append(301) @@ -135,8 +129,8 @@ def SetAllNonLazyFields(message): message.repeated_float.append(311) message.repeated_double.append(312) message.repeated_bool.append(False) - message.repeated_string.append('315') - message.repeated_bytes.append('316') + message.repeated_string.append(u'315') + message.repeated_bytes.append(b'316') message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 @@ -148,8 +142,8 @@ def SetAllNonLazyFields(message): message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) - message.repeated_string_piece.append('324') - message.repeated_cord.append('325') + message.repeated_string_piece.append(u'324') + message.repeated_cord.append(u'325') # # Fields that have defaults. @@ -169,7 +163,7 @@ def SetAllNonLazyFields(message): message.default_double = 412 message.default_bool = False message.default_string = '415' - message.default_bytes = '416' + message.default_bytes = b'416' message.default_nested_enum = unittest_pb2.TestAllTypes.FOO message.default_foreign_enum = unittest_pb2.FOREIGN_FOO @@ -178,6 +172,11 @@ def SetAllNonLazyFields(message): message.default_string_piece = '424' message.default_cord = '425' + message.oneof_uint32 = 601 + message.oneof_nested_message.bb = 602 + message.oneof_string = '603' + message.oneof_bytes = b'604' + def SetAllFields(message): SetAllNonLazyFields(message) @@ -212,8 +211,8 @@ def SetAllExtensions(message): extensions[pb2.optional_float_extension] = 111 extensions[pb2.optional_double_extension] = 112 extensions[pb2.optional_bool_extension] = True - extensions[pb2.optional_string_extension] = '115' - extensions[pb2.optional_bytes_extension] = '116' + extensions[pb2.optional_string_extension] = u'115' + extensions[pb2.optional_bytes_extension] = b'116' extensions[pb2.optionalgroup_extension].a = 117 extensions[pb2.optional_nested_message_extension].bb = 118 @@ -227,8 +226,8 @@ def SetAllExtensions(message): extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ - extensions[pb2.optional_string_piece_extension] = '124' - extensions[pb2.optional_cord_extension] = '125' + extensions[pb2.optional_string_piece_extension] = u'124' + extensions[pb2.optional_cord_extension] = u'125' # # Repeated fields. @@ -247,8 +246,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(211) extensions[pb2.repeated_double_extension].append(212) extensions[pb2.repeated_bool_extension].append(True) - extensions[pb2.repeated_string_extension].append('215') - extensions[pb2.repeated_bytes_extension].append('216') + extensions[pb2.repeated_string_extension].append(u'215') + extensions[pb2.repeated_bytes_extension].append(b'216') extensions[pb2.repeatedgroup_extension].add().a = 217 extensions[pb2.repeated_nested_message_extension].add().bb = 218 @@ -260,8 +259,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) - extensions[pb2.repeated_string_piece_extension].append('224') - extensions[pb2.repeated_cord_extension].append('225') + extensions[pb2.repeated_string_piece_extension].append(u'224') + extensions[pb2.repeated_cord_extension].append(u'225') # Append a second one of each field. extensions[pb2.repeated_int32_extension].append(301) @@ -277,8 +276,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(311) extensions[pb2.repeated_double_extension].append(312) extensions[pb2.repeated_bool_extension].append(False) - extensions[pb2.repeated_string_extension].append('315') - extensions[pb2.repeated_bytes_extension].append('316') + extensions[pb2.repeated_string_extension].append(u'315') + extensions[pb2.repeated_bytes_extension].append(b'316') extensions[pb2.repeatedgroup_extension].add().a = 317 extensions[pb2.repeated_nested_message_extension].add().bb = 318 @@ -290,8 +289,8 @@ def SetAllExtensions(message): extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) - extensions[pb2.repeated_string_piece_extension].append('324') - extensions[pb2.repeated_cord_extension].append('325') + extensions[pb2.repeated_string_piece_extension].append(u'324') + extensions[pb2.repeated_cord_extension].append(u'325') # # Fields with defaults. @@ -310,16 +309,21 @@ def SetAllExtensions(message): extensions[pb2.default_float_extension] = 411 extensions[pb2.default_double_extension] = 412 extensions[pb2.default_bool_extension] = False - extensions[pb2.default_string_extension] = '415' - extensions[pb2.default_bytes_extension] = '416' + extensions[pb2.default_string_extension] = u'415' + extensions[pb2.default_bytes_extension] = b'416' extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO - extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_string_piece_extension] = u'424' extensions[pb2.default_cord_extension] = '425' + extensions[pb2.oneof_uint32_extension] = 601 + extensions[pb2.oneof_nested_message_extension].bb = 602 + extensions[pb2.oneof_string_extension] = u'603' + extensions[pb2.oneof_bytes_extension] = b'604' + def SetAllFieldsAndExtensions(message): """Sets every field and extension in the message to a unique value. @@ -358,7 +362,7 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized): message.my_float = 1.0 expected_strings.append(message.SerializeToString()) message.Clear() - expected = ''.join(expected_strings) + expected = b''.join(expected_strings) if expected != serialized: raise ValueError('Expected %r, found %r' % (expected, serialized)) @@ -413,7 +417,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(112, message.optional_double) test_case.assertEqual(True, message.optional_bool) test_case.assertEqual('115', message.optional_string) - test_case.assertEqual('116', message.optional_bytes) + test_case.assertEqual(b'116', message.optional_bytes) test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) @@ -472,7 +476,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(212, message.repeated_double[0]) test_case.assertEqual(True, message.repeated_bool[0]) test_case.assertEqual('215', message.repeated_string[0]) - test_case.assertEqual('216', message.repeated_bytes[0]) + test_case.assertEqual(b'216', message.repeated_bytes[0]) test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) @@ -501,7 +505,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(312, message.repeated_double[1]) test_case.assertEqual(False, message.repeated_bool[1]) test_case.assertEqual('315', message.repeated_string[1]) - test_case.assertEqual('316', message.repeated_bytes[1]) + test_case.assertEqual(b'316', message.repeated_bytes[1]) test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) @@ -552,7 +556,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(412, message.default_double) test_case.assertEqual(False, message.default_bool) test_case.assertEqual('415', message.default_string) - test_case.assertEqual('416', message.default_bytes) + test_case.assertEqual(b'416', message.default_bytes) test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) @@ -561,6 +565,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, message.default_import_enum) + def GoldenFile(filename): """Finds the given golden file and returns a file object representing it.""" @@ -574,9 +579,15 @@ def GoldenFile(filename): path = os.path.join(path, '..') raise RuntimeError( - 'Could not find golden files. This test must be run from within the ' - 'protobuf source package so that it can read test data files from the ' - 'C++ source tree.') + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def GoldenFileData(filename): + """Finds the given golden file and returns its contents.""" + with GoldenFile(filename) as f: + return f.read() def SetAllPackedFields(message): diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py new file mode 100755 index 00000000..ba0e45d6 --- /dev/null +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -0,0 +1,68 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.text_encoding.""" + +from google.apputils import basetest +from google.protobuf import text_encoding + +TEST_VALUES = [ + ("foo\\rbar\\nbaz\\t", + "foo\\rbar\\nbaz\\t", + b"foo\rbar\nbaz\t"), + ("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + "\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + b"'full of \"sound\" and \"fury\"'"), + ("signi\\\\fying\\\\ nothing\\\\", + "signi\\\\fying\\\\ nothing\\\\", + b"signi\\fying\\ nothing\\"), + ("\\010\\t\\n\\013\\014\\r", + "\x08\\t\\n\x0b\x0c\\r", + b"\010\011\012\013\014\015")] + + +class TextEncodingTestCase(basetest.TestCase): + def testCEscape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(escaped, + text_encoding.CEscape(unescaped, as_utf8=False)) + self.assertEquals(escaped_utf8, + text_encoding.CEscape(unescaped, as_utf8=True)) + + def testCUnescape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) + self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) + + +if __name__ == "__main__": + basetest.main() diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 4b1b4f59..a85d4dd8 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -34,49 +34,71 @@ __author__ = 'kenton@google.com (Kenton Varda)' -import difflib import re -import unittest +from google.apputils import basetest from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import unittest_pb2 from google.protobuf import unittest_mset_pb2 +class TextFormatTest(basetest.TestCase): -class TextFormatTest(unittest.TestCase): def ReadGolden(self, golden_filename): - f = test_util.GoldenFile(golden_filename) - golden_lines = f.readlines() - f.close() - return golden_lines + with test_util.GoldenFile(golden_filename) as f: + return (f.readlines() if str is bytes else # PY3 + [golden_line.decode('utf-8') for golden_line in f]) def CompareToGoldenFile(self, text, golden_filename): golden_lines = self.ReadGolden(golden_filename) - self.CompareToGoldenLines(text, golden_lines) + self.assertMultiLineEqual(text, ''.join(golden_lines)) def CompareToGoldenText(self, text, golden_text): - self.CompareToGoldenLines(text, golden_text.splitlines(1)) - - def CompareToGoldenLines(self, text, golden_lines): - actual_lines = text.splitlines(1) - self.assertEqual(golden_lines, actual_lines, - "Text doesn't match golden. Diff:\n" + - ''.join(difflib.ndiff(golden_lines, actual_lines))) + self.assertMultiLineEqual(text, golden_text) def testPrintAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n') def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof_implemented.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') def testPrintMessageSet(self): message = unittest_mset_pb2.TestMessageSetContainer() @@ -84,37 +106,16 @@ class TextFormatTest(unittest.TestCase): ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText(text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') - - def testPrintBadEnumValue(self): - message = unittest_pb2.TestAllTypes() - message.optional_nested_enum = 100 - message.optional_foreign_enum = 101 - message.optional_import_enum = 102 - self.CompareToGoldenText( - text_format.MessageToString(message), - 'optional_nested_enum: 100\n' - 'optional_foreign_enum: 101\n' - 'optional_import_enum: 102\n') - - def testPrintBadEnumValueExtensions(self): - message = unittest_pb2.TestAllExtensions() - message.Extensions[unittest_pb2.optional_nested_enum_extension] = 100 - message.Extensions[unittest_pb2.optional_foreign_enum_extension] = 101 - message.Extensions[unittest_pb2.optional_import_enum_extension] = 102 self.CompareToGoldenText( text_format.MessageToString(message), - '[protobuf_unittest.optional_nested_enum_extension]: 100\n' - '[protobuf_unittest.optional_foreign_enum_extension]: 101\n' - '[protobuf_unittest.optional_import_enum_extension]: 102\n') + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') def testPrintExotic(self): message = unittest_pb2.TestAllTypes() @@ -126,20 +127,29 @@ class TextFormatTest(unittest.TestCase): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'repeated_int64: -9223372036854775808\n' - 'repeated_uint64: 18446744073709551615\n' - 'repeated_double: 123.456\n' - 'repeated_double: 1.23e+22\n' - 'repeated_double: 1.23e-18\n' - 'repeated_string: ' - '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' - 'repeated_string: "\\303\\274\\352\\234\\237"\n') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string:' + ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintExoticUnicodeSubclass(self): + class UnicodeSub(unicode): + pass + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) + self.CompareToGoldenText( + text_format.MessageToString(message), + 'repeated_string: "\\303\\274\\352\\234\\237"\n') def testPrintNestedMessageAsOneLine(self): message = unittest_pb2.TestAllTypes() msg = message.repeated_nested_message.add() - msg.bb = 42; + msg.bb = 42 self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_nested_message { bb: 42 }') @@ -190,16 +200,16 @@ class TextFormatTest(unittest.TestCase): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros( - text_format.MessageToString(message, as_one_line=True)), - 'repeated_int64: -9223372036854775808' - ' repeated_uint64: 18446744073709551615' - ' repeated_double: 123.456' - ' repeated_double: 1.23e+22' - ' repeated_double: 1.23e-18' - ' repeated_string: ' - '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' - ' repeated_string: "\\303\\274\\352\\234\\237"') + self.RemoveRedundantZeros( + text_format.MessageToString(message, as_one_line=True)), + 'repeated_int64: -9223372036854775808' + ' repeated_uint64: 18446744073709551615' + ' repeated_double: 123.456' + ' repeated_double: 1.23e+22' + ' repeated_double: 1.23e-18' + ' repeated_string: ' + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' + ' repeated_string: "\\303\\274\\352\\234\\237"') def testRoundTripExoticAsOneLine(self): message = unittest_pb2.TestAllTypes() @@ -215,24 +225,60 @@ class TextFormatTest(unittest.TestCase): wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=False) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(wire_text, parsed_message) + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) self.assertEquals(message, parsed_message) # Test as_utf8 = True. wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=True) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(wire_text, parsed_message) - self.assertEquals(message, parsed_message) + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) def testPrintRawUtf8String(self): message = unittest_pb2.TestAllTypes() message.repeated_string.append(u'\u00fc\ua71f') - text = text_format.MessageToString(message, as_utf8 = True) + text = text_format.MessageToString(message, as_utf8=True) self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(text, parsed_message) - self.assertEquals(message, parsed_message) + text_format.Parse(text, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintFloatFormat(self): + # Check that float_format argument is passed to sub-message formatting. + message = unittest_pb2.NestedTestAllTypes() + # We use 1.25 as it is a round number in binary. The proto 32-bit float + # will not gain additional imprecise digits as a 64-bit Python float and + # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: + # >>> struct.unpack('f', struct.pack('f', 1.2))[0] + # 1.2000000476837158 + # >>> struct.unpack('f', struct.pack('f', 1.25))[0] + # 1.25 + message.payload.optional_float = 1.25 + # Check rounding at 15 significant digits + message.payload.optional_double = -.000003456789012345678 + # Check no decimal point. + message.payload.repeated_float.append(-5642) + # Check no trailing zeros. + message.payload.repeated_double.append(.000078900) + formatted_fields = ['optional_float: 1.25', + 'optional_double: -3.45678901234568e-6', + 'repeated_float: -5642', + 'repeated_double: 7.89e-5'] + text_message = text_format.MessageToString(message, float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) + # as_one_line=True is a separate code branch where float_format is passed. + text_message = text_format.MessageToString(message, as_one_line=True, + float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) def testMessageToString(self): message = unittest_pb2.ForeignMessage() @@ -249,49 +295,50 @@ class TextFormatTest(unittest.TestCase): text = re.compile('\.0$', re.MULTILINE).sub('', text) return text - def testMergeGolden(self): + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(golden_text, parsed_message) + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.assertEquals(message, parsed_message) - def testMergeGoldenExtensions(self): + def testParseGoldenExtensions(self): golden_text = '\n'.join(self.ReadGolden( 'text_format_unittest_extensions_data.txt')) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(golden_text, parsed_message) + text_format.Parse(golden_text, parsed_message) message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.assertEquals(message, parsed_message) - def testMergeAllFields(self): + def testParseAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) test_util.ExpectAllFieldsSet(self, message) - def testMergeAllExtensions(self): + def testParseAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - def testMergeMessageSet(self): + def testParseMessageSet(self): message = unittest_pb2.TestAllTypes() text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -304,13 +351,13 @@ class TextFormatTest(unittest.TestCase): ' str: \"foo\"\n' ' }\n' '}\n') - text_format.Merge(text, message) + text_format.Parse(text, message) ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension self.assertEquals(23, message.message_set.Extensions[ext1].i) self.assertEquals('foo', message.message_set.Extensions[ext2].str) - def testMergeExotic(self): + def testParseExotic(self): message = unittest_pb2.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' @@ -323,7 +370,7 @@ class TextFormatTest(unittest.TestCase): 'repeated_string: "\\303\\274\\352\\234\\237"\n' 'repeated_string: "\\xc3\\xbc"\n' 'repeated_string: "\xc3\xbc"\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(-9223372036854775808, message.repeated_int64[0]) self.assertEqual(18446744073709551615, message.repeated_uint64[0]) @@ -336,100 +383,115 @@ class TextFormatTest(unittest.TestCase): self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) self.assertEqual(u'\u00fc', message.repeated_string[3]) - def testMergeEmptyText(self): + def testParseTrailingCommas(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: 100;\n' + 'repeated_int64: 200;\n' + 'repeated_int64: 300,\n' + 'repeated_string: "one",\n' + 'repeated_string: "two";\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + + def testParseEmptyText(self): message = unittest_pb2.TestAllTypes() text = '' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEquals(unittest_pb2.TestAllTypes(), message) - def testMergeInvalidUtf8(self): + def testParseInvalidUtf8(self): message = unittest_pb2.TestAllTypes() text = 'repeated_string: "\\xc3\\xc3"' - self.assertRaises(text_format.ParseError, text_format.Merge, text, message) + self.assertRaises(text_format.ParseError, text_format.Parse, text, message) - def testMergeSingleWord(self): + def testParseSingleWord(self): message = unittest_pb2.TestAllTypes() text = 'foo' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"foo".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeUnknownField(self): + def testParseUnknownField(self): message = unittest_pb2.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"unknown_field".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadExtension(self): + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' 'extensions.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeGroupNotClosed(self): + def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected ">".', - text_format.Merge, text, message) + text_format.Parse, text, message) text = 'RepeatedGroup: {' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected "}".', - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeEmptyGroup(self): + def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: {}' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) message.Clear() message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: <>' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) - def testMergeBadEnumValue(self): + def testParseBadEnumValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value named BARR.'), - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value with number 100.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadIntValue(self): + def testParseBadIntValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_int32: bork' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:17 : Couldn\'t parse integer: bork'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeStringFieldUnescape(self): + def testParseStringFieldUnescape(self): message = unittest_pb2.TestAllTypes() text = r'''repeated_string: "\xf\x62" repeated_string: "\\xf\\x62" @@ -437,7 +499,7 @@ class TextFormatTest(unittest.TestCase): repeated_string: "\\\\xf\\\\x62" repeated_string: "\\\\\xf\\\\\x62" repeated_string: "\x5cx20"''' - text_format.Merge(text, message) + text_format.Parse(text, message) SLASH = '\\' self.assertEqual('\x0fb', message.repeated_string[0]) @@ -449,27 +511,84 @@ class TextFormatTest(unittest.TestCase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) - def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): - """Same as assertRaises, but also compares the exception message.""" - if hasattr(e_class, '__name__'): - exc_name = e_class.__name__ - else: - exc_name = str(e_class) + def testMergeRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + r = text_format.Merge(text, message) + self.assertIs(r, message) + self.assertEqual(67, message.optional_int32) + + def testParseRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + r = text_format.Merge(text, message) + self.assertTrue(r is message) + self.assertEqual(2, message.optional_nested_message.bb) + + def testParseRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + text_format.Merge(text, message) + self.assertEqual( + 67, + message.Extensions[unittest_pb2.optional_int32_extension]) + + def testParseRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' + 'should not have multiple ' + '"protobuf_unittest.optional_int32_extension" extensions.'), + text_format.Parse, text, message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) - try: - func(*args, **kwargs) - except e_class as expr: - if str(expr) != e: - msg = '%s raised, but with wrong message: "%s" instead of "%s"' - raise self.failureException(msg % (exc_name, - str(expr).encode('string_escape'), - e.encode('string_escape'))) - return - else: - raise self.failureException('%s not raised' % exc_name) + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) -class TokenizerTest(unittest.TestCase): +class TokenizerTest(basetest.TestCase): def testSimpleTokenCases(self): text = ('identifier1:"string1"\n \n\n' @@ -478,8 +597,8 @@ class TokenizerTest(unittest.TestCase): 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' - 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' ) - tokenizer = text_format._Tokenizer(text) + 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') + tokenizer = text_format._Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -565,7 +684,7 @@ class TokenizerTest(unittest.TestCase): int64_max = (1 << 63) - 1 uint32_max = (1 << 32) - 1 text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) self.assertEqual(-1, tokenizer.ConsumeInt32()) @@ -579,7 +698,7 @@ class TokenizerTest(unittest.TestCase): self.assertTrue(tokenizer.AtEnd()) text = '-0 -0 0 0' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertEqual(0, tokenizer.ConsumeUint32()) self.assertEqual(0, tokenizer.ConsumeUint64()) self.assertEqual(0, tokenizer.ConsumeUint32()) @@ -588,30 +707,30 @@ class TokenizerTest(unittest.TestCase): def testConsumeByteString(self): text = '"string1\'' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = 'string1"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\xt"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\x"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) def testConsumeBool(self): text = 'not-a-bool' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 2b3cd4de..8e1b3cc3 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2008 Google Inc. All Rights Reserved. + """Provides type checking routines. This module defines type checking utilities in the forms of dictionaries: @@ -45,6 +49,9 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' +import sys ##PY25 +if sys.version < '2.6': bytes = str ##PY25 +from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import wire_format @@ -53,21 +60,22 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor -def GetTypeChecker(cpp_type, field_type): +def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. Args: - cpp_type: C++ type of the field (see descriptor.py). - field_type: Protocol message field type (see descriptor.py). + field: FieldDescriptor object for this field. Returns: An instance of TypeChecker which can be used to verify the types of values assigned to a field of the specified type. """ - if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and - field_type == _FieldDescriptor.TYPE_STRING): + if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() - return _VALUE_CHECKERS[cpp_type] + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + return EnumValueChecker(field.enum_type) + return _VALUE_CHECKERS[field.cpp_type] # None of the typecheckers below make any attempt to guard against people @@ -85,10 +93,15 @@ class TypeChecker(object): self._acceptable_types = acceptable_types def CheckValue(self, proposed_value): + """Type check the provided value and return it. + + The returned value might have been normalized to another type. + """ if not isinstance(proposed_value, self._acceptable_types): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), self._acceptable_types)) raise TypeError(message) + return proposed_value # IntValueChecker and its subclasses perform integer type-checks @@ -104,28 +117,54 @@ class IntValueChecker(object): raise TypeError(message) if not self._MIN <= proposed_value <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + proposed_value = self._TYPE(proposed_value) + return proposed_value + + +class EnumValueChecker(object): + + """Checker used for enum fields. Performs type-check and range check.""" + + def __init__(self, enum_type): + self._enum_type = enum_type + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if proposed_value not in self._enum_type.values_by_number: + raise ValueError('Unknown enum value: %d' % proposed_value) + return proposed_value class UnicodeValueChecker(object): - """Checker used for string fields.""" + """Checker used for string fields. + + Always returns a unicode value, even if the input is of type str. + """ def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (str, unicode)): + if not isinstance(proposed_value, (bytes, unicode)): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (str, unicode))) + (proposed_value, type(proposed_value), (bytes, unicode))) raise TypeError(message) - # If the value is of type 'str' make sure that it is in 7-bit ASCII + # If the value is of type 'bytes' make sure that it is in 7-bit ASCII # encoding. - if isinstance(proposed_value, str): + if isinstance(proposed_value, bytes): try: - unicode(proposed_value, 'ascii') + proposed_value = proposed_value.decode('ascii') except UnicodeDecodeError: - raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII ' 'encoding. Non-ASCII strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + return proposed_value class Int32ValueChecker(IntValueChecker): @@ -133,21 +172,25 @@ class Int32ValueChecker(IntValueChecker): # efficient. _MIN = -2147483648 _MAX = 2147483647 + _TYPE = int class Uint32ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 32) - 1 + _TYPE = int class Int64ValueChecker(IntValueChecker): _MIN = -(1 << 63) _MAX = (1 << 63) - 1 + _TYPE = long class Uint64ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 64) - 1 + _TYPE = long # Type-checkers for all scalar CPPTYPEs. @@ -161,8 +204,7 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( float, int, long), _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 84984b40..8f3354c9 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -35,15 +35,16 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf.internal import encoder +from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers -class UnknownFieldsTest(unittest.TestCase): +class UnknownFieldsTest(basetest.TestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -58,12 +59,20 @@ class UnknownFieldsTest(unittest.TestCase): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} for tag_bytes, value in self.unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes] - result_dict = {} decoder(value, 0, len(value), self.all_fields, result_dict) - return result_dict[field_descriptor] + return result_dict[field_descriptor] + + def testEnum(self): + value = self.GetField('optional_nested_enum') + self.assertEqual(self.all_fields.optional_nested_enum, value) + + def testRepeatedEnum(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.all_fields.repeated_nested_enum, value) def testVarint(self): value = self.GetField('optional_int32') @@ -166,5 +175,57 @@ class UnknownFieldsTest(unittest.TestCase): self.assertNotEqual(self.empty_message, message) +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR + + self.message = missing_enum_values_pb2.TestEnumValues() + self.message.optional_nested_enum = ( + missing_enum_values_pb2.TestEnumValues.ZERO) + self.message.repeated_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message.packed_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message_data = self.message.SerializeToString() + self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() + self.missing_message.ParseFromString(self.message_data) + self.unknown_fields = self.missing_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ + tag_bytes] + decoder(value, 0, len(value), self.message, result_dict) + return result_dict[field_descriptor] + + def testUnknownEnumValue(self): + self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + value = self.GetField('optional_nested_enum') + self.assertEqual(self.message.optional_nested_enum, value) + + def testUnknownRepeatedEnumValue(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.message.repeated_nested_enum, value) + + def testUnknownPackedEnumValue(self): + value = self.GetField('packed_nested_enum') + self.assertEqual(self.message.packed_nested_enum, value) + + def testRoundTrip(self): + new_message = missing_enum_values_pb2.TestEnumValues() + new_message.ParseFromString(self.missing_message.SerializeToString()) + self.assertEqual(self.message, new_message) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 76007786..9362c72d 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -34,12 +34,12 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import message from google.protobuf.internal import wire_format -class WireFormatTest(unittest.TestCase): +class WireFormatTest(basetest.TestCase): def testPackTag(self): field_number = 0xabc @@ -195,7 +195,7 @@ class WireFormatTest(unittest.TestCase): # Test UTF-8 string byte size calculation. # 1 byte for tag, 1 byte for length, 8 bytes for content. self.assertEqual(10, wire_format.StringByteSize( - 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + 5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8'))) class MockMessage(object): def __init__(self, byte_size): @@ -250,4 +250,4 @@ class WireFormatTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() |