diff options
author | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:48:38 -0800 |
---|---|---|
committer | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:59:59 -0800 |
commit | 5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch) | |
tree | 0276f81f8848a05d84cd7e287b43d665e30f04e3 /python/google/protobuf/pyext | |
parent | e28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff) |
Integrated internal changes from Google
Diffstat (limited to 'python/google/protobuf/pyext')
-rw-r--r-- | python/google/protobuf/pyext/descriptor_pool.cc | 67 | ||||
-rw-r--r-- | python/google/protobuf/pyext/extension_dict.cc | 77 | ||||
-rw-r--r-- | python/google/protobuf/pyext/map_container.cc | 2 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message.cc | 409 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message.h | 24 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message_factory.cc | 66 | ||||
-rw-r--r-- | python/google/protobuf/pyext/message_factory.h | 12 | ||||
-rw-r--r-- | python/google/protobuf/pyext/safe_numerics.h | 164 |
8 files changed, 577 insertions, 244 deletions
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index a42e5431..fa66bf9a 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -319,6 +319,51 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { return PyFileDescriptor_FromDescriptor(file_descriptor); } +PyObject* FindExtensionByNumber(PyDescriptorPool* self, PyObject* args) { + PyObject* message_descriptor; + int number; + if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) { + return NULL; + } + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor( + message_descriptor); + if (descriptor == NULL) { + return NULL; + } + + const FieldDescriptor* extension_descriptor = + self->pool->FindExtensionByNumber(descriptor, number); + if (extension_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find extension %d", number); + return NULL; + } + + return PyFieldDescriptor_FromDescriptor(extension_descriptor); +} + +PyObject* FindAllExtensions(PyDescriptorPool* self, PyObject* arg) { + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg); + if (descriptor == NULL) { + return NULL; + } + + std::vector<const FieldDescriptor*> extensions; + self->pool->FindAllExtensions(descriptor, &extensions); + + ScopedPyObjectPtr result(PyList_New(extensions.size())); + if (result == NULL) { + return NULL; + } + for (int i = 0; i < extensions.size(); i++) { + PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]); + if (extension == NULL) { + return NULL; + } + PyList_SET_ITEM(result.get(), i, extension); // Steals the reference. + } + return result.release(); +} + // These functions should not exist -- the only valid way to create // descriptors is to call Add() or AddSerializedFile(). // But these AddDescriptor() functions were created in Python and some people @@ -376,6 +421,22 @@ PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) { Py_RETURN_NONE; } +PyObject* AddExtensionDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const FieldDescriptor* extension_descriptor = + PyFieldDescriptor_AsDescriptor(descriptor); + if (!extension_descriptor) { + return NULL; + } + if (extension_descriptor != + self->pool->FindExtensionByName(extension_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The extension descriptor %s does not belong to this pool", + extension_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + // The code below loads new Descriptors from a serialized FileDescriptorProto. @@ -475,6 +536,8 @@ static PyMethodDef Methods[] = { "No-op. Add() must have been called before." }, { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O, "No-op. Add() must have been called before." }, + { "AddExtensionDescriptor", (PyCFunction)AddExtensionDescriptor, METH_O, + "No-op. Add() must have been called before." }, { "FindFileByName", (PyCFunction)FindFileByName, METH_O, "Searches for a file descriptor by its .proto name." }, @@ -495,6 +558,10 @@ static PyMethodDef Methods[] = { { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, "Gets the FileDescriptor containing the specified symbol." }, + { "FindExtensionByNumber", (PyCFunction)FindExtensionByNumber, METH_VARARGS, + "Gets the extension descriptor for the given number." }, + { "FindAllExtensions", (PyCFunction)FindAllExtensions, METH_O, + "Gets all known extensions of the given message descriptor." }, {NULL} }; diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index dbb7bca0..9423c1d8 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -38,6 +38,7 @@ #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> +#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/message_factory.h> @@ -46,6 +47,16 @@ #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/shared_ptr.h> +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + namespace google { namespace protobuf { namespace python { @@ -90,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // TODO(plabatut): consider building the class on the fly! PyObject* sub_message = cmessage::InternalGetSubMessage( self->parent, descriptor); if (sub_message == NULL) { @@ -101,7 +113,17 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = message_factory::GetMessageClass( + // On the fly message class creation is needed to support the following + // situation: + // 1- add FileDescriptor to the pool that contains extensions of a message + // defined by another proto file. Do not create any message classes. + // 2- instantiate an extended message, and access the extension using + // the field descriptor. + // 3- the extension submessage fails to be returned, because no class has + // been created. + // It happens when deserializing text proto format, or when enumerating + // fields of a deserialized message. + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( cmessage::GetFactoryForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { @@ -154,34 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } -PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); - if (extensions_by_name == NULL) { +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = + pool->pool->FindExtensionByName(string(name, name_size)); + if (message_extension == NULL) { + // Is is the name of a message set extension? + const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( + string(name, name_size)); + if (message_descriptor && message_descriptor->extension_count() > 0) { + const FieldDescriptor* extension = message_descriptor->extension(0); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL) { + message_extension = extension; + } + } + } + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } -PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { - ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); - if (extensions_by_number == NULL) { +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { + int64 number = PyLong_AsLong(arg); + if (number == -1 && PyErr_Occurred()) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( + self->parent->message->GetDescriptor(), number); + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } ExtensionDict* NewExtensionDict(CMessage *parent) { diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 318c2e7c..088ddf93 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -374,7 +374,7 @@ static int InitializeAndCopyToParentContainer(MapContainer* from, // A somewhat roundabout way of copying just one field from old_message to // new_message. This is the best we can do with what Reflection gives us. Message* mutable_old = from->GetMutableMessage(); - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; fields.push_back(from->parent_field_descriptor); // Move the map field into the new message. diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 7ff99aea..5967a587 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -64,11 +64,11 @@ #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 - #define PyInt_Check PyLong_Check #define PyInt_AsLong PyLong_AsLong #define PyInt_FromLong PyLong_FromLong #define PyInt_FromSize_t PyLong_FromSize_t @@ -92,8 +92,6 @@ namespace protobuf { namespace python { static PyObject* kDESCRIPTOR; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; @@ -128,19 +126,6 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { - return -1; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { - return -1; - } - } - // For each field set: cls.<field>_FIELD_NUMBER = <number> for (int i = 0; i < descriptor->field_count(); ++i) { if (!AddFieldNumberToClass(cls, descriptor->field(i))) { @@ -357,6 +342,61 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { #endif // PY_MAJOR_VERSION >= 3 } +// The _extensions_by_name dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() +static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), + extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +// The _extensions_by_number dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() +static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +static PyGetSetDef Getters[] = { + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + } // namespace message_meta PyTypeObject CMessageClass_Type = { @@ -389,7 +429,7 @@ PyTypeObject CMessageClass_Type = { 0, // tp_iternext 0, // tp_methods 0, // tp_members - 0, // tp_getset + message_meta::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -525,23 +565,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -// Constants used for integer type range checking. -PyObject* kPythonZero; -PyObject* kint32min_py; -PyObject* kint32max_py; -PyObject* kuint32max_py; -PyObject* kint64min_py; -PyObject* kint64max_py; -PyObject* kuint64max_py; - PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -// Constant PyString values used for GetAttr/GetItem. -static PyObject* k_cdescriptor; -static PyObject* kfull_name; - /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { PyObject* repr = PyObject_Repr(arg); @@ -555,68 +582,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) { } } -template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max) { - bool is_long = PyLong_Check(arg); -#if PY_MAJOR_VERSION < 3 - if (!PyInt_Check(arg) && !is_long) { - FormatTypeError(arg, "int, long"); - return false; +void OutOfRangeError(PyObject* arg) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); } - if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { -#else - if (!is_long) { - FormatTypeError(arg, "int"); +} + +template<class RangeType, class ValueType> +bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { + if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. return false; } - if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || - PyObject_RichCompareBool(max, arg, Py_GE) != 1) { -#endif - if (!PyErr_Occurred()) { - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); - } - } + if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { + OutOfRangeError(arg); return false; } + return true; +} + +template<class T> +bool CheckAndGetInteger(PyObject* arg, T* value) { + // The fast path. #if PY_MAJOR_VERSION < 3 - if (!is_long) { - *value = static_cast<T>(PyInt_AsLong(arg)); - } else // NOLINT + // For the typical case, offer a fast path. + if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { + long int_result = PyInt_AsLong(arg); + if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { + *value = static_cast<T>(int_result); + return true; + } else { + OutOfRangeError(arg); + return false; + } + } #endif - { - if (min == kPythonZero) { - *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); + // This effectively defines an integer as "an object that can be cast as + // an integer and can be used as an ordinal number". + // This definition includes everything that implements numbers.Integral + // and shouldn't cast the net too wide. + if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { + FormatTypeError(arg, "int, long"); + return false; + } + + // Now we have an integral number so we can safely use PyLong_ functions. + // We need to treat the signed and unsigned cases differently in case arg is + // holding a value above the maximum for signed longs. + if (std::numeric_limits<T>::min() == 0) { + // Unsigned case. + unsigned PY_LONG_LONG ulong_result; + if (PyLong_Check(arg)) { + ulong_result = PyLong_AsUnsignedLongLong(arg); + } else { + // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very + // picky about the exact type. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + ulong_result = PyLong_AsUnsignedLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, + ulong_result)) { + *value = static_cast<T>(ulong_result); + } else { + return false; + } + } else { + // Signed case. + PY_LONG_LONG long_result; + PyNumberMethods *nb; + if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { + // PyLong_AsLongLong requires it to be a long or to have an __int__() + // method. + long_result = PyLong_AsLongLong(arg); } else { - *value = static_cast<T>(PyLong_AsLongLong(arg)); + // Valid subclasses of numbers.Integral should have a __long__() method + // so fall back to that. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + long_result = PyLong_AsLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { + *value = static_cast<T>(long_result); + } else { + return false; } } + return true; } // These are referenced by repeated_scalar_container, and must // be explicitly instantiated. -template bool CheckAndGetInteger<int32>( - PyObject*, int32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<int64>( - PyObject*, int64*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint32>( - PyObject*, uint32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint64>( - PyObject*, uint64*, PyObject*, PyObject*); +template bool CheckAndGetInteger<int32>(PyObject*, int32*); +template bool CheckAndGetInteger<int64>(PyObject*, int64*); +template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { - if (!PyInt_Check(arg) && !PyLong_Check(arg) && - !PyFloat_Check(arg)) { + *value = PyFloat_AsDouble(arg); + if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, float"); return false; } - *value = PyFloat_AsDouble(arg); return true; } @@ -630,11 +715,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) { } bool CheckAndGetBool(PyObject* arg, bool* value) { - if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + long long_value = PyInt_AsLong(arg); + if (long_value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, bool"); return false; } - *value = static_cast<bool>(PyInt_AsLong(arg)); + *value = static_cast<bool>(long_value); + return true; } @@ -966,20 +1053,7 @@ int InternalDeleteRepeatedField( int min, max; length = reflection->FieldSize(*message, field_descriptor); - if (PyInt_Check(slice) || PyLong_Check(slice)) { - from = to = PyLong_AsLong(slice); - if (from < 0) { - from = to = length + from; - } - step = 1; - min = max = from; - - // Range check. - if (from < 0 || from >= length) { - PyErr_Format(PyExc_IndexError, "list assignment index out of range"); - return -1; - } - } else if (PySlice_Check(slice)) { + if (PySlice_Check(slice)) { from = to = step = slice_length = 0; PySlice_GetIndicesEx( #if PY_MAJOR_VERSION < 3 @@ -996,8 +1070,23 @@ int InternalDeleteRepeatedField( max = from; } } else { - PyErr_SetString(PyExc_TypeError, "list indices must be integers"); - return -1; + from = to = PyLong_AsLong(slice); + if (from == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } } Py_ssize_t i = from; @@ -1958,99 +2047,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { return PyLong_FromLong(self->message->ByteSize()); } -static PyObject* RegisterExtension(PyObject* cls, - PyObject* extension_handle) { +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { return NULL; } - - ScopedPyObjectPtr extensions_by_name( - PyObject_GetAttr(cls, k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", + cls->ob_type->tp_name); return NULL; } - ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); - if (full_name == NULL) { + CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); + if (message_class == NULL) { return NULL; } - // If the extension was already registered, check that it is the same. - PyObject* existing_extension = - PyDict_GetItem(extensions_by_name.get(), full_name.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - - if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), - extension_handle) < 0) { - return NULL; - } - - // Also store a mapping from extension number to implementing class. - ScopedPyObjectPtr extensions_by_number( - PyObject_GetAttr(cls, k_extensions_by_number)); - if (extensions_by_number == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); - return NULL; - } - - ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); - if (number == NULL) { - return NULL; - } - - // If the extension was already registered by number, check that it is the - // same. - existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - const Descriptor* msg_desc = GetMessageDescriptor( - reinterpret_cast<PyTypeObject*>(cls)); - PyErr_Format( - PyExc_ValueError, - "Extensions \"%s\" and \"%s\" both try to extend message type " - "\"%s\" with field number %ld.", - existing_extension_descriptor->full_name().c_str(), - descriptor->full_name().c_str(), - msg_desc->full_name().c_str(), - PyInt_AsLong(number.get())); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - if (PyDict_SetItem(extensions_by_number.get(), number.get(), - extension_handle) < 0) { + const FieldDescriptor* existing_extension = + message_class->py_message_factory->pool->pool->FindExtensionByNumber( + descriptor->containing_type(), descriptor->number()); + if (existing_extension != NULL && existing_extension != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); return NULL; } - - // Check if it's a message set - if (descriptor->is_extension() && - descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { - ScopedPyObjectPtr message_name(PyString_FromStringAndSize( - descriptor->message_type()->full_name().c_str(), - descriptor->message_type()->full_name().size())); - if (message_name == NULL) { - return NULL; - } - PyDict_SetItem(extensions_by_name.get(), message_name.get(), - extension_handle); - } - Py_RETURN_NONE; } @@ -2087,7 +2106,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { static PyObject* GetExtensionDict(CMessage* self, void *closure); static PyObject* ListFields(CMessage* self) { - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); // Normally, the list will be exactly the size of the fields. @@ -2178,7 +2197,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) { PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; - vector<string> errors; + std::vector<string> errors; message->FindInitializationErrors(&errors); PyObject* error_list = PyList_New(errors.size()); @@ -2570,11 +2589,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) { return NULL; } +static PyObject* GetExtensionsByName(CMessage *self, void *closure) { + return message_meta::GetExtensionsByName( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { + return message_meta::GetExtensionsByNumber( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + static PyGetSetDef Getters[] = { {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, {NULL} }; + static PyMethodDef Methods[] = { { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -2835,19 +2867,7 @@ void InitGlobals() { // TODO(gps): Check all return values in this function for NULL and propagate // the error (MemoryError) on up to result in an import failure. These should // also be freed and reset to NULL during finalization. - kPythonZero = PyInt_FromLong(0); - kint32min_py = PyInt_FromLong(kint32min); - kint32max_py = PyInt_FromLong(kint32max); - kuint32max_py = PyLong_FromLongLong(kuint32max); - kint64min_py = PyLong_FromLongLong(kint64min); - kint64max_py = PyLong_FromLongLong(kint64max); - kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); - kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); - k_cdescriptor = PyString_FromString("_cdescriptor"); - kfull_name = PyString_FromString("full_name"); - k_extensions_by_name = PyString_FromString("_extensions_by_name"); - k_extensions_by_number = PyString_FromString("_extensions_by_number"); PyObject *dummy_obj = PySet_New(NULL); kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); @@ -2887,25 +2907,6 @@ bool InitProto2MessageModule(PyObject *m) { // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); - // Subclasses with message extensions will override _extensions_by_name and - // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. - // All other classes can share this same immutable mapping. - ScopedPyObjectPtr empty_dict(PyDict_New()); - if (empty_dict == NULL) { - return false; - } - ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); - if (immutable_dict == NULL) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_name, immutable_dict.get()) < 0) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_number, immutable_dict.get()) < 0) { - return false; - } PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 1550724c..ce80497e 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -117,6 +117,7 @@ typedef struct CMessage { PyObject* weakreflist; } CMessage; +extern PyTypeObject CMessageClass_Type; extern PyTypeObject CMessage_Type; @@ -235,6 +236,10 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs); PyObject* MergeFrom(CMessage* self, PyObject* arg); +// This method does not do anything beyond checking that no other extension +// has been registered with the same field number on this class. +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle); + // Retrieves an attribute named 'name' from CMessage 'self'. Returns // the attribute value on success, or NULL on failure. // @@ -275,25 +280,25 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); #define GOOGLE_CHECK_GET_INT32(arg, value, err) \ int32 value; \ - if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_INT64(arg, value, err) \ int64 value; \ - if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_UINT32(arg, value, err) \ uint32 value; \ - if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_UINT64(arg, value, err) \ uint64 value; \ - if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } @@ -316,20 +321,11 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); } -extern PyObject* kPythonZero; -extern PyObject* kint32min_py; -extern PyObject* kint32max_py; -extern PyObject* kuint32max_py; -extern PyObject* kint64min_py; -extern PyObject* kint64max_py; -extern PyObject* kuint64max_py; - #define FULL_MODULE_NAME "google.protobuf.pyext._message" void FormatTypeError(PyObject* arg, char* expected_types); template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max); +bool CheckAndGetInteger(PyObject* arg, T* value); bool CheckAndGetDouble(PyObject* arg, double* value); bool CheckAndGetFloat(PyObject* arg, float* value); bool CheckAndGetBool(PyObject* arg, bool* value); diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index 2ad89022..e0b45bf2 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -130,6 +130,72 @@ int RegisterMessageClass(PyMessageFactory* self, return 0; } +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* descriptor) { + // This is the same implementation as MessageFactory.GetPrototype(). + ScopedPyObjectPtr py_descriptor( + PyMessageDescriptor_FromDescriptor(descriptor)); + if (py_descriptor == NULL) { + return NULL; + } + // Do not create a MessageClass that already exists. + hash_map<const Descriptor*, CMessageClass*>::iterator it = + self->classes_by_descriptor->find(descriptor); + if (it != self->classes_by_descriptor->end()) { + Py_INCREF(it->second); + return it->second; + } + // Create a new message class. + ScopedPyObjectPtr args(Py_BuildValue( + "s(){sOsOsO}", descriptor->name().c_str(), + "DESCRIPTOR", py_descriptor.get(), + "__module__", Py_None, + "message_factory", self)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr message_class(PyObject_CallObject( + reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get())); + if (message_class == NULL) { + return NULL; + } + // Create messages class for the messages used by the fields, and registers + // all extensions for these messages during the recursion. + for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) { + const Descriptor* sub_descriptor = + descriptor->field(field_idx)->message_type(); + // It is NULL if the field type is not a message. + if (sub_descriptor != NULL) { + CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor); + if (result == NULL) { + return NULL; + } + Py_DECREF(result); + } + } + + // Register extensions defined in this message. + for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) { + const FieldDescriptor* extension = descriptor->extension(ext_idx); + ScopedPyObjectPtr py_extended_class( + GetOrCreateMessageClass(self, extension->containing_type()) + ->AsPyObject()); + if (py_extended_class == NULL) { + return NULL; + } + ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension)); + if (py_extension == NULL) { + return NULL; + } + ScopedPyObjectPtr result(cmessage::RegisterExtension( + py_extended_class.get(), py_extension.get())); + if (result == NULL) { + return NULL; + } + } + return reinterpret_cast<CMessageClass*>(message_class.release()); +} + // Retrieve the message class added to our database. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor) { diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h index 07cccbfb..36092f7e 100644 --- a/python/google/protobuf/pyext/message_factory.h +++ b/python/google/protobuf/pyext/message_factory.h @@ -82,14 +82,14 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool); int RegisterMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor, CMessageClass* message_class); - -// Retrieves the Python class registered with the given message descriptor. -// -// Returns a *borrowed* reference if found, otherwise returns NULL with an -// exception set. +// Retrieves the Python class registered with the given message descriptor, or +// fail with a TypeError. Returns a *borrowed* reference. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor); - +// Retrieves the Python class registered with the given message descriptor. +// The class is created if not done yet. Returns a *new* reference. +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor); } // namespace message_factory // Initialize objects used by this module. diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h new file mode 100644 index 00000000..639ba2c8 --- /dev/null +++ b/python/google/protobuf/pyext/safe_numerics.h @@ -0,0 +1,164 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +// Copied from chromium with only changes to the namespace. + +#include <limits> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> + +namespace google { +namespace protobuf { +namespace python { + +template <bool SameSize, bool DestLarger, + bool DestIsSigned, bool SourceIsSigned> +struct IsValidNumericCastImpl; + +#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \ +template <> struct IsValidNumericCastImpl<A, B, C, D> { \ + template <class Source, class DestBounds> static inline bool Test( \ + Source source, DestBounds min, DestBounds max) { \ + return Code; \ + } \ +} + +#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, true, DestSigned, SourceSigned, Code); \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, false, DestSigned, SourceSigned, Code) + +#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, false, DestSigned, SourceSigned, Code); \ + +#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, true, DestSigned, SourceSigned, Code); \ + +// The three top level cases are: +// - Same size +// - Source larger +// - Dest larger +// And for each of those three cases, we handle the 4 different possibilities +// of signed and unsigned. This gives 12 cases to handle, which we enumerate +// below. +// +// The last argument in each of the macros is the actual comparison code. It +// has three arguments available, source (the value), and min/max which are +// the ranges of the destination. + + +// These are the cases where both types have the same size. + +// Both signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true); +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true); +// Dest unsigned, Source signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0); +// Dest signed, Source unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Source is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max); +// Both signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true, + source >= min && source <= max); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true, + source >= 0 && source <= max); +// Dest is signed, Source is unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Dest is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true); +// Both signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0); +// Dest is signed, Source is unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true); + +#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION +#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE +#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER +#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER + + +// The main test for whether the conversion will under or overflow. +template <class Dest, class Source> +inline bool IsValidNumericCast(Source source) { + typedef std::numeric_limits<Source> SourceLimits; + typedef std::numeric_limits<Dest> DestLimits; + GOOGLE_COMPILE_ASSERT(SourceLimits::is_specialized, argument_must_be_numeric); + GOOGLE_COMPILE_ASSERT(SourceLimits::is_integer, argument_must_be_integral); + GOOGLE_COMPILE_ASSERT(DestLimits::is_specialized, result_must_be_numeric); + GOOGLE_COMPILE_ASSERT(DestLimits::is_integer, result_must_be_integral); + + return IsValidNumericCastImpl< + sizeof(Dest) == sizeof(Source), + (sizeof(Dest) > sizeof(Source)), + DestLimits::is_signed, + SourceLimits::is_signed>::Test( + source, + DestLimits::min(), + DestLimits::max()); +} + +// checked_numeric_cast<> is analogous to static_cast<> for numeric types, +// except that it CHECKs that the specified numeric conversion will not +// overflow or underflow. Floating point arguments are not currently allowed +// (this is COMPILE_ASSERTd), though this could be supported if necessary. +template <class Dest, class Source> +inline Dest checked_numeric_cast(Source source) { + GOOGLE_CHECK(IsValidNumericCast<Dest>(source)); + return static_cast<Dest>(source); +} + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ |