aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/pyext/message.cc
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/pyext/message.cc')
-rw-r--r--python/google/protobuf/pyext/message.cc1033
1 files changed, 516 insertions, 517 deletions
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 60ec9c1b..53736b9c 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -35,9 +35,6 @@
#include <map>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <string>
#include <vector>
#include <structmember.h> // A Python header file.
@@ -52,6 +49,7 @@
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
@@ -63,11 +61,11 @@
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/map_container.h>
+#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
- #define PyInt_Check PyLong_Check
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
@@ -91,42 +89,26 @@ namespace protobuf {
namespace python {
static PyObject* kDESCRIPTOR;
-static PyObject* k_extensions_by_name;
-static PyObject* k_extensions_by_number;
PyObject* EnumTypeWrapper_class;
static PyObject* PythonMessage_class;
static PyObject* kEmptyWeakref;
static PyObject* WKT_classes = NULL;
-// Defines the Metaclass of all Message classes.
-// It allows us to cache some C++ pointers in the class object itself, they are
-// faster to extract than from the type's dictionary.
-
-struct PyMessageMeta {
- // This is how CPython subclasses C structures: the base structure must be
- // the first member of the object.
- PyHeapTypeObject super;
-
- // C++ descriptor of this message.
- const Descriptor* message_descriptor;
-
- // Owned reference, used to keep the pointer above alive.
- PyObject* py_message_descriptor;
-
- // The Python DescriptorPool used to create the class. It is needed to resolve
- // fields descriptors, including extensions fields; its C++ MessageFactory is
- // used to instantiate submessages.
- // This can be different from DESCRIPTOR.file.pool, in the case of a custom
- // DescriptorPool which defines new extensions.
- // We own the reference, because it's important to keep the descriptors and
- // factory alive.
- PyDescriptorPool* py_descriptor_pool;
-};
-
namespace message_meta {
static int InsertEmptyWeakref(PyTypeObject* base);
+namespace {
+// Copied oveer from internal 'google/protobuf/stubs/strutil.h'.
+inline void UpperString(string * s) {
+ string::iterator end = s->end();
+ for (string::iterator i = s->begin(); i != end; ++i) {
+ // toupper() changes based on locale. We don't want this!
+ if ('a' <= *i && *i <= 'z') *i += 'A' - 'a';
+ }
+}
+}
+
// Add the number of a field descriptor to the containing message class.
// Equivalent to:
// _cls.<field>_FIELD_NUMBER = <number>
@@ -152,19 +134,6 @@ static bool AddFieldNumberToClass(
// Finalize the creation of the Message class.
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
- // If there are extension_ranges, the message is "extendable", and extension
- // classes will register themselves in this class.
- if (descriptor->extension_range_count() > 0) {
- ScopedPyObjectPtr by_name(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) {
- return -1;
- }
- ScopedPyObjectPtr by_number(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) {
- return -1;
- }
- }
-
// For each field set: cls.<field>_FIELD_NUMBER = <number>
for (int i = 0; i < descriptor->field_count(); ++i) {
if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
@@ -173,10 +142,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
}
// For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
- //
- // The enum descriptor we get from
- // <messagedescriptor>.enum_types_by_name[name]
- // which was built previously.
for (int i = 0; i < descriptor->enum_type_count(); ++i) {
const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
ScopedPyObjectPtr enum_type(
@@ -273,6 +238,12 @@ static PyObject* New(PyTypeObject* type,
return NULL;
}
+ // Messages have no __dict__
+ ScopedPyObjectPtr slots(PyTuple_New(0));
+ if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
+ return NULL;
+ }
+
// Build the arguments to the base metaclass.
// We change the __bases__ classes.
ScopedPyObjectPtr new_args;
@@ -309,7 +280,7 @@ static PyObject* New(PyTypeObject* type,
if (result == NULL) {
return NULL;
}
- PyMessageMeta* newtype = reinterpret_cast<PyMessageMeta*>(result.get());
+ CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
// Insert the empty weakref into the base classes.
if (InsertEmptyWeakref(
@@ -329,16 +300,19 @@ static PyObject* New(PyTypeObject* type,
newtype->message_descriptor = descriptor;
// TODO(amauryfa): Don't always use the canonical pool of the descriptor,
// use the MessageFactory optionally passed in the class dict.
- newtype->py_descriptor_pool = GetDescriptorPool_FromPool(
- descriptor->file()->pool());
- if (newtype->py_descriptor_pool == NULL) {
+ PyDescriptorPool* py_descriptor_pool =
+ GetDescriptorPool_FromPool(descriptor->file()->pool());
+ if (py_descriptor_pool == NULL) {
return NULL;
}
- Py_INCREF(newtype->py_descriptor_pool);
+ newtype->py_message_factory = py_descriptor_pool->py_message_factory;
+ Py_INCREF(newtype->py_message_factory);
- // Add the message to the DescriptorPool.
- if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool,
- descriptor, result.get()) < 0) {
+ // Register the message in the MessageFactory.
+ // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
+ // MessageFactory is fully implemented in C++.
+ if (message_factory::RegisterMessageClass(newtype->py_message_factory,
+ descriptor, newtype) < 0) {
return NULL;
}
@@ -349,9 +323,9 @@ static PyObject* New(PyTypeObject* type,
return result.release();
}
-static void Dealloc(PyMessageMeta *self) {
- Py_DECREF(self->py_message_descriptor);
- Py_DECREF(self->py_descriptor_pool);
+static void Dealloc(CMessageClass *self) {
+ Py_XDECREF(self->py_message_descriptor);
+ Py_XDECREF(self->py_message_factory);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -376,12 +350,67 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) {
#endif // PY_MAJOR_VERSION >= 3
}
+// The _extensions_by_name dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
+static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
+ extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+// The _extensions_by_number dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
+static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+static PyGetSetDef Getters[] = {
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
+ {NULL}
+};
+
} // namespace message_meta
-PyTypeObject PyMessageMeta_Type = {
+PyTypeObject CMessageClass_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".MessageMeta", // tp_name
- sizeof(PyMessageMeta), // tp_basicsize
+ sizeof(CMessageClass), // tp_basicsize
0, // tp_itemsize
(destructor)message_meta::Dealloc, // tp_dealloc
0, // tp_print
@@ -408,7 +437,7 @@ PyTypeObject PyMessageMeta_Type = {
0, // tp_iternext
0, // tp_methods
0, // tp_members
- 0, // tp_getset
+ message_meta::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -419,16 +448,16 @@ PyTypeObject PyMessageMeta_Type = {
message_meta::New, // tp_new
};
-static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) {
- if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) {
+static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
return NULL;
}
- return reinterpret_cast<PyMessageMeta*>(cls);
+ return reinterpret_cast<CMessageClass*>(cls);
}
static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
- PyMessageMeta* type = CheckMessageClass(cls);
+ CMessageClass* type = CheckMessageClass(cls);
if (type == NULL) {
return NULL;
}
@@ -544,23 +573,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
// ---------------------------------------------------------------------
-// Constants used for integer type range checking.
-PyObject* kPythonZero;
-PyObject* kint32min_py;
-PyObject* kint32max_py;
-PyObject* kuint32max_py;
-PyObject* kint64min_py;
-PyObject* kint64max_py;
-PyObject* kuint64max_py;
-
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
-// Constant PyString values used for GetAttr/GetItem.
-static PyObject* k_cdescriptor;
-static PyObject* kfull_name;
-
/* Is 64bit */
void FormatTypeError(PyObject* arg, char* expected_types) {
PyObject* repr = PyObject_Repr(arg);
@@ -574,68 +590,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) {
}
}
-template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max) {
- bool is_long = PyLong_Check(arg);
-#if PY_MAJOR_VERSION < 3
- if (!PyInt_Check(arg) && !is_long) {
- FormatTypeError(arg, "int, long");
- return false;
+void OutOfRangeError(PyObject* arg) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
}
- if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) {
-#else
- if (!is_long) {
- FormatTypeError(arg, "int");
+}
+
+template<class RangeType, class ValueType>
+bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
+ if (GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ // Replace it with the same ValueError as pure python protos instead of
+ // the default one.
+ PyErr_Clear();
+ OutOfRangeError(arg);
+ } // Otherwise propagate existing error.
return false;
- }
- if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
- PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
-#endif
- if (!PyErr_Occurred()) {
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
- }
}
- return false;
- }
+ if (GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
+ OutOfRangeError(arg);
+ return false;
+ }
+ return true;
+}
+
+template<class T>
+bool CheckAndGetInteger(PyObject* arg, T* value) {
+ // The fast path.
#if PY_MAJOR_VERSION < 3
- if (!is_long) {
- *value = static_cast<T>(PyInt_AsLong(arg));
- } else // NOLINT
+ // For the typical case, offer a fast path.
+ if (GOOGLE_PREDICT_TRUE(PyInt_Check(arg))) {
+ long int_result = PyInt_AsLong(arg);
+ if (GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
+ *value = static_cast<T>(int_result);
+ return true;
+ } else {
+ OutOfRangeError(arg);
+ return false;
+ }
+ }
#endif
- {
- if (min == kPythonZero) {
- *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg));
+ // This effectively defines an integer as "an object that can be cast as
+ // an integer and can be used as an ordinal number".
+ // This definition includes everything that implements numbers.Integral
+ // and shouldn't cast the net too wide.
+ if (GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg))) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
+
+ // Now we have an integral number so we can safely use PyLong_ functions.
+ // We need to treat the signed and unsigned cases differently in case arg is
+ // holding a value above the maximum for signed longs.
+ if (std::numeric_limits<T>::min() == 0) {
+ // Unsigned case.
+ unsigned PY_LONG_LONG ulong_result;
+ if (PyLong_Check(arg)) {
+ ulong_result = PyLong_AsUnsignedLongLong(arg);
} else {
- *value = static_cast<T>(PyLong_AsLongLong(arg));
+ // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
+ // picky about the exact type.
+ PyObject* casted = PyNumber_Long(arg);
+ if (GOOGLE_PREDICT_FALSE(casted == nullptr)) {
+ // Propagate existing error.
+ return false;
+ }
+ ulong_result = PyLong_AsUnsignedLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
+ ulong_result)) {
+ *value = static_cast<T>(ulong_result);
+ } else {
+ return false;
+ }
+ } else {
+ // Signed case.
+ PY_LONG_LONG long_result;
+ PyNumberMethods *nb;
+ if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
+ // PyLong_AsLongLong requires it to be a long or to have an __int__()
+ // method.
+ long_result = PyLong_AsLongLong(arg);
+ } else {
+ // Valid subclasses of numbers.Integral should have a __long__() method
+ // so fall back to that.
+ PyObject* casted = PyNumber_Long(arg);
+ if (GOOGLE_PREDICT_FALSE(casted == nullptr)) {
+ // Propagate existing error.
+ return false;
+ }
+ long_result = PyLong_AsLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
+ *value = static_cast<T>(long_result);
+ } else {
+ return false;
}
}
+
return true;
}
// These are referenced by repeated_scalar_container, and must
// be explicitly instantiated.
-template bool CheckAndGetInteger<int32>(
- PyObject*, int32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<int64>(
- PyObject*, int64*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint32>(
- PyObject*, uint32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint64>(
- PyObject*, uint64*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<int32>(PyObject*, int32*);
+template bool CheckAndGetInteger<int64>(PyObject*, int64*);
+template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
+template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
bool CheckAndGetDouble(PyObject* arg, double* value) {
- if (!PyInt_Check(arg) && !PyLong_Check(arg) &&
- !PyFloat_Check(arg)) {
+ *value = PyFloat_AsDouble(arg);
+ if (GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
FormatTypeError(arg, "int, long, float");
return false;
- }
- *value = PyFloat_AsDouble(arg);
+ }
return true;
}
@@ -649,11 +723,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) {
}
bool CheckAndGetBool(PyObject* arg, bool* value) {
- if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) {
+ long long_value = PyInt_AsLong(arg);
+ if (long_value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, bool");
return false;
}
- *value = static_cast<bool>(PyInt_AsLong(arg));
+ *value = static_cast<bool>(long_value);
+
return true;
}
@@ -711,7 +787,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
encoded_string = arg; // Already encoded.
Py_INCREF(encoded_string);
} else {
- encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
+ encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
}
} else {
// In this case field type is "bytes".
@@ -751,7 +827,8 @@ bool CheckAndSetString(
return true;
}
-PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) {
+PyObject* ToStringObject(const FieldDescriptor* descriptor,
+ const string& value) {
if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
return PyBytes_FromStringAndSize(value.c_str(), value.length());
}
@@ -781,15 +858,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
namespace cmessage {
-PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) {
- // No need to check the type: the type of instances of CMessage is always
- // an instance of PyMessageMeta. Let's prove it with a debug-only check.
+PyMessageFactory* GetFactoryForMessage(CMessage* message) {
GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
- return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool;
-}
-
-MessageFactory* GetFactoryForMessage(CMessage* message) {
- return GetDescriptorPoolForMessage(message)->message_factory;
+ return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
}
static int MaybeReleaseOverlappingOneofField(
@@ -842,7 +913,8 @@ static Message* GetMutableMessage(
return NULL;
}
return reflection->MutableMessage(
- parent_message, parent_field, GetFactoryForMessage(parent));
+ parent_message, parent_field,
+ GetFactoryForMessage(parent)->message_factory);
}
struct FixupMessageReference : public ChildVisitor {
@@ -990,28 +1062,17 @@ int InternalDeleteRepeatedField(
int min, max;
length = reflection->FieldSize(*message, field_descriptor);
- if (PyInt_Check(slice) || PyLong_Check(slice)) {
- from = to = PyLong_AsLong(slice);
- if (from < 0) {
- from = to = length + from;
- }
- step = 1;
- min = max = from;
-
- // Range check.
- if (from < 0 || from >= length) {
- PyErr_Format(PyExc_IndexError, "list assignment index out of range");
- return -1;
- }
- } else if (PySlice_Check(slice)) {
+ if (PySlice_Check(slice)) {
from = to = step = slice_length = 0;
- PySlice_GetIndicesEx(
#if PY_MAJOR_VERSION < 3
+ PySlice_GetIndicesEx(
reinterpret_cast<PySliceObject*>(slice),
+ length, &from, &to, &step, &slice_length);
#else
+ PySlice_GetIndicesEx(
slice,
-#endif
length, &from, &to, &step, &slice_length);
+#endif
if (from < to) {
min = from;
max = to - 1;
@@ -1020,8 +1081,23 @@ int InternalDeleteRepeatedField(
max = from;
}
} else {
- PyErr_SetString(PyExc_TypeError, "list indices must be integers");
- return -1;
+ from = to = PyLong_AsLong(slice);
+ if (from == -1 && PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ if (from < 0) {
+ from = to = length + from;
+ }
+ step = 1;
+ min = max = from;
+
+ // Range check.
+ if (from < 0 || from >= length) {
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range");
+ return -1;
+ }
}
Py_ssize_t i = from;
@@ -1070,7 +1146,12 @@ int InternalDeleteRepeatedField(
}
// Initializes fields of a message. Used in constructors.
-int InitAttributes(CMessage* self, PyObject* kwargs) {
+int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
+ if (args != NULL && PyTuple_Size(args) != 0) {
+ PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
+ return -1;
+ }
+
if (kwargs == NULL) {
return 0;
}
@@ -1090,8 +1171,12 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
PyString_AsString(name));
return -1;
}
+ if (value == Py_None) {
+ // field=None is the same as no field at all.
+ continue;
+ }
if (descriptor->is_map()) {
- ScopedPyObjectPtr map(GetAttr(self, name));
+ ScopedPyObjectPtr map(GetAttr(reinterpret_cast<PyObject*>(self), name));
const FieldDescriptor* value_descriptor =
descriptor->message_type()->FindFieldByName("value");
if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -1119,7 +1204,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
}
} else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- ScopedPyObjectPtr container(GetAttr(self, name));
+ ScopedPyObjectPtr container(
+ GetAttr(reinterpret_cast<PyObject*>(self), name));
if (container == NULL) {
return -1;
}
@@ -1186,13 +1272,16 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
}
} else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- ScopedPyObjectPtr message(GetAttr(self, name));
+ ScopedPyObjectPtr message(
+ GetAttr(reinterpret_cast<PyObject*>(self), name));
if (message == NULL) {
return -1;
}
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
if (PyDict_Check(value)) {
- if (InitAttributes(cmessage, value) < 0) {
+ // Make the message exist even if the dict is empty.
+ AssureWritable(cmessage);
+ if (InitAttributes(cmessage, NULL, value) < 0) {
return -1;
}
} else {
@@ -1209,8 +1298,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
return -1;
}
}
- if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) <
- 0) {
+ if (SetAttr(reinterpret_cast<PyObject*>(self), name,
+ (new_val.get() == NULL) ? value : new_val.get()) < 0) {
return -1;
}
}
@@ -1220,13 +1309,15 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
// Allocates an incomplete Python Message: the caller must fill self->message,
// self->owner and eventually self->parent.
-CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
+CMessage* NewEmptyMessage(CMessageClass* type) {
CMessage* self = reinterpret_cast<CMessage*>(
- PyType_GenericAlloc(reinterpret_cast<PyTypeObject*>(type), 0));
+ PyType_GenericAlloc(&type->super.ht_type, 0));
if (self == NULL) {
return NULL;
}
+ // Use "placement new" syntax to initialize the C++ object.
+ new (&self->owner) CMessage::OwnerRef(NULL);
self->message = NULL;
self->parent = NULL;
self->parent_field_descriptor = NULL;
@@ -1242,7 +1333,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
// Creates a new C++ message and takes ownership.
static PyObject* New(PyTypeObject* cls,
PyObject* unused_args, PyObject* unused_kwargs) {
- PyMessageMeta* type = CheckMessageClass(cls);
+ CMessageClass* type = CheckMessageClass(cls);
if (type == NULL) {
return NULL;
}
@@ -1251,15 +1342,14 @@ static PyObject* New(PyTypeObject* cls,
if (message_descriptor == NULL) {
return NULL;
}
- const Message* default_message = type->py_descriptor_pool->message_factory
+ const Message* default_message = type->py_message_factory->message_factory
->GetPrototype(message_descriptor);
if (default_message == NULL) {
PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
return NULL;
}
- CMessage* self = NewEmptyMessage(reinterpret_cast<PyObject*>(type),
- message_descriptor);
+ CMessage* self = NewEmptyMessage(type);
if (self == NULL) {
return NULL;
}
@@ -1271,12 +1361,7 @@ static PyObject* New(PyTypeObject* cls,
// The __init__ method of Message classes.
// It initializes fields from keywords passed to the constructor.
static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
- if (PyTuple_Size(args) != 0) {
- PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
- return -1;
- }
-
- return InitAttributes(self, kwargs);
+ return InitAttributes(self, args, kwargs);
}
// ---------------------------------------------------------------------
@@ -1318,6 +1403,9 @@ struct ClearWeakReferences : public ChildVisitor {
};
static void Dealloc(CMessage* self) {
+ if (self->weakreflist) {
+ PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
+ }
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
if (self->extensions) {
@@ -1326,7 +1414,7 @@ static void Dealloc(CMessage* self) {
Py_CLEAR(self->extensions);
Py_CLEAR(self->composite_fields);
- self->owner.reset();
+ self->owner.~ThreadUnsafeSharedPtr<Message>();
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -1467,36 +1555,25 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
if (message->GetReflection()->HasField(*message, field_descriptor)) {
Py_RETURN_TRUE;
}
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
- // Special case: Python HasField() differs in semantics from C++
- // slightly: we return HasField('enum_field') == true if there is
- // an unknown enum value present. To implement this we have to
- // look in the UnknownFieldSet.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() == field_descriptor->number()) {
- Py_RETURN_TRUE;
- }
- }
- }
+
Py_RETURN_FALSE;
}
PyObject* ClearExtension(CMessage* self, PyObject* extension) {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
+ }
if (self->extensions != NULL) {
- return extension_dict::ClearExtension(self->extensions, extension);
- } else {
- const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) {
- return NULL;
+ PyObject* value = PyDict_GetItem(self->extensions->values, extension);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->extensions->values, extension);
}
}
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, descriptor);
}
PyObject* HasExtension(CMessage* self, PyObject* extension) {
@@ -1539,9 +1616,10 @@ PyObject* HasExtension(CMessage* self, PyObject* extension) {
// * Clear the weak references from the released container to the
// parent.
-struct SetOwnerVisitor : public ChildVisitor {
+class SetOwnerVisitor : public ChildVisitor {
+ public:
// new_owner must outlive this object.
- explicit SetOwnerVisitor(const shared_ptr<Message>& new_owner)
+ explicit SetOwnerVisitor(const CMessage::OwnerRef& new_owner)
: new_owner_(new_owner) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
@@ -1565,11 +1643,11 @@ struct SetOwnerVisitor : public ChildVisitor {
}
private:
- const shared_ptr<Message>& new_owner_;
+ const CMessage::OwnerRef& new_owner_;
};
// Change the owner of this CMessage and all its children, recursively.
-int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
+int SetOwner(CMessage* self, const CMessage::OwnerRef& new_owner) {
self->owner = new_owner;
if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1)
return -1;
@@ -1582,7 +1660,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
Message* ReleaseMessage(CMessage* self,
const Descriptor* descriptor,
const FieldDescriptor* field_descriptor) {
- MessageFactory* message_factory = GetFactoryForMessage(self);
+ MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory;
Message* released_message = self->message->GetReflection()->ReleaseMessage(
self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
@@ -1602,7 +1680,7 @@ int ReleaseSubMessage(CMessage* self,
const FieldDescriptor* field_descriptor,
CMessage* child_cmessage) {
// Release the Message
- shared_ptr<Message> released_message(ReleaseMessage(
+ CMessage::OwnerRef released_message(ReleaseMessage(
self, child_cmessage->message->GetDescriptor(), field_descriptor));
child_cmessage->message = released_message.get();
child_cmessage->owner.swap(released_message);
@@ -1619,23 +1697,20 @@ struct ReleaseChild : public ChildVisitor {
parent_(parent) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
- return repeated_composite_container::Release(
- reinterpret_cast<RepeatedCompositeContainer*>(container));
+ return repeated_composite_container::Release(container);
}
int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
- return repeated_scalar_container::Release(
- reinterpret_cast<RepeatedScalarContainer*>(container));
+ return repeated_scalar_container::Release(container);
}
int VisitMapContainer(MapContainer* container) {
- return reinterpret_cast<MapContainer*>(container)->Release();
+ return container->Release();
}
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
- return ReleaseSubMessage(parent_, field_descriptor,
- reinterpret_cast<CMessage*>(cmessage));
+ return ReleaseSubMessage(parent_, field_descriptor, cmessage);
}
CMessage* parent_;
@@ -1653,12 +1728,13 @@ int InternalReleaseFieldByDescriptor(
PyObject* ClearFieldByDescriptor(
CMessage* self,
- const FieldDescriptor* descriptor) {
- if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
+ const FieldDescriptor* field_descriptor) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
return NULL;
}
AssureWritable(self);
- self->message->GetReflection()->ClearField(self->message, descriptor);
+ Message* message = self->message;
+ message->GetReflection()->ClearField(message, field_descriptor);
Py_RETURN_NONE;
}
@@ -1694,27 +1770,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
arg = arg_in_oneof.get();
}
- PyObject* composite_field = self->composite_fields ?
- PyDict_GetItem(self->composite_fields, arg) : NULL;
-
- // Only release the field if there's a possibility that there are
- // references to it.
- if (composite_field != NULL) {
- if (InternalReleaseFieldByDescriptor(self, field_descriptor,
- composite_field) < 0) {
- return NULL;
+ // Release the field if it exists in the dict of composite fields.
+ if (self->composite_fields) {
+ PyObject* value = PyDict_GetItem(self->composite_fields, arg);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->composite_fields, arg);
}
- PyDict_DelItem(self->composite_fields, arg);
- }
- message->GetReflection()->ClearField(message, field_descriptor);
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
- !message->GetReflection()->SupportsUnknownEnumValues()) {
- UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- unknown_field_set->DeleteByNumber(field_descriptor->number());
}
-
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, field_descriptor);
}
PyObject* Clear(CMessage* self) {
@@ -1739,8 +1805,25 @@ static string GetMessageName(CMessage* self) {
}
}
-static PyObject* SerializeToString(CMessage* self, PyObject* args) {
- if (!self->message->IsInitialized()) {
+static PyObject* InternalSerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs,
+ bool require_initialized) {
+ // Parse the "deterministic" kwarg; defaults to False.
+ static char* kwlist[] = { "deterministic", 0 };
+ PyObject* deterministic_obj = Py_None;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
+ &deterministic_obj)) {
+ return NULL;
+ }
+ // Preemptively convert to a bool first, so we don't need to back out of
+ // allocating memory if this raises an exception.
+ // NOTE: This is unused later if deterministic == Py_None, but that's fine.
+ int deterministic = PyObject_IsTrue(deterministic_obj);
+ if (deterministic < 0) {
+ return NULL;
+ }
+
+ if (require_initialized && !self->message->IsInitialized()) {
ScopedPyObjectPtr errors(FindInitializationErrors(self));
if (errors == NULL) {
return NULL;
@@ -1778,24 +1861,36 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) {
GetMessageName(self).c_str(), PyString_AsString(joined.get()));
return NULL;
}
- int size = self->message->ByteSize();
- if (size <= 0) {
+
+ // Ok, arguments parsed and errors checked, now encode to a string
+ const size_t size = self->message->ByteSizeLong();
+ if (size == 0) {
return PyBytes_FromString("");
}
PyObject* result = PyBytes_FromStringAndSize(NULL, size);
if (result == NULL) {
return NULL;
}
- char* buffer = PyBytes_AS_STRING(result);
- self->message->SerializeWithCachedSizesToArray(
- reinterpret_cast<uint8*>(buffer));
+ io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
+ io::CodedOutputStream coded_out(&out);
+ if (deterministic_obj != Py_None) {
+ coded_out.SetSerializationDeterministic(deterministic);
+ }
+ self->message->SerializeWithCachedSizes(&coded_out);
+ GOOGLE_CHECK(!coded_out.HadError());
return result;
}
-static PyObject* SerializePartialToString(CMessage* self) {
- string contents;
- self->message->SerializePartialToString(&contents);
- return PyBytes_FromStringAndSize(contents.c_str(), contents.size());
+static PyObject* SerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/true);
+}
+
+static PyObject* SerializePartialToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/false);
}
// Formats proto fields for ascii dumps using python formatting functions where
@@ -1851,8 +1946,12 @@ static PyObject* ToStr(CMessage* self) {
PyObject* MergeFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1860,8 +1959,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to merge from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1879,8 +1978,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1893,8 +1996,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to copy from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1911,6 +2014,34 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
Py_RETURN_NONE;
}
+// Protobuf has a 64MB limit built in, this variable will override this. Please
+// do not enable this unless you fully understand the implications: protobufs
+// must all be kept in memory at the same time, so if they grow too big you may
+// get OOM errors. The protobuf APIs do not provide any tools for processing
+// protobufs in chunks. If you have protos this big you should break them up if
+// it is at all convenient to do so.
+#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
+static bool allow_oversize_protos = true;
+#else
+static bool allow_oversize_protos = false;
+#endif
+
+// Provide a method in the module to set allow_oversize_protos to a boolean
+// value. This method returns the newly value of allow_oversize_protos.
+PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
+ if (!arg || !PyBool_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Argument to SetAllowOversizeProtos must be boolean");
+ return NULL;
+ }
+ allow_oversize_protos = PyObject_IsTrue(arg);
+ if (allow_oversize_protos) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
const void* data;
Py_ssize_t data_length;
@@ -1921,19 +2052,18 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
AssureWritable(self);
io::CodedInputStream input(
reinterpret_cast<const uint8*>(data), data_length);
-#if PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
- // Protobuf has a 64MB limit built in, this code will override this. Please do
- // not enable this unless you fully understand the implications: protobufs
- // must all be kept in memory at the same time, so if they grow too big you
- // may get OOM errors. The protobuf APIs do not provide any tools for
- // processing protobufs in chunks. If you have protos this big you should
- // break them up if it is at all convenient to do so.
- input.SetTotalBytesLimit(INT_MAX, INT_MAX);
-#endif // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
- input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ if (allow_oversize_protos) {
+ input.SetTotalBytesLimit(INT_MAX, INT_MAX);
+ }
+ PyMessageFactory* factory = GetFactoryForMessage(self);
+ input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
if (success) {
+ if (!input.ConsumedEntireMessage()) {
+ // TODO(jieluo): Raise error and return NULL instead.
+ // b/27494216
+ PyErr_Warn(NULL, "Unexpected end-group tag: Not all data was converted");
+ }
return PyInt_FromLong(input.CurrentPosition());
} else {
PyErr_Format(DecodeError_class, "Error parsing message");
@@ -1952,75 +2082,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
return PyLong_FromLong(self->message->ByteSize());
}
-static PyObject* RegisterExtension(PyObject* cls,
- PyObject* extension_handle) {
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
const FieldDescriptor* descriptor =
GetExtensionDescriptor(extension_handle);
if (descriptor == NULL) {
return NULL;
}
-
- ScopedPyObjectPtr extensions_by_name(
- PyObject_GetAttr(cls, k_extensions_by_name));
- if (extensions_by_name == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class");
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
+ cls->ob_type->tp_name);
return NULL;
}
- ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name));
- if (full_name == NULL) {
+ CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
+ if (message_class == NULL) {
return NULL;
}
-
// If the extension was already registered, check that it is the same.
- PyObject* existing_extension =
- PyDict_GetItem(extensions_by_name.get(), full_name.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
-
- if (PyDict_SetItem(extensions_by_name.get(), full_name.get(),
- extension_handle) < 0) {
- return NULL;
- }
-
- // Also store a mapping from extension number to implementing class.
- ScopedPyObjectPtr extensions_by_number(
- PyObject_GetAttr(cls, k_extensions_by_number));
- if (extensions_by_number == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
- return NULL;
- }
- ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
- if (number == NULL) {
- return NULL;
- }
- if (PyDict_SetItem(extensions_by_number.get(), number.get(),
- extension_handle) < 0) {
+ const FieldDescriptor* existing_extension =
+ message_class->py_message_factory->pool->pool->FindExtensionByNumber(
+ descriptor->containing_type(), descriptor->number());
+ if (existing_extension != NULL && existing_extension != descriptor) {
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
return NULL;
}
-
- // Check if it's a message set
- if (descriptor->is_extension() &&
- descriptor->containing_type()->options().message_set_wire_format() &&
- descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
- descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
- ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
- descriptor->message_type()->full_name().c_str(),
- descriptor->message_type()->full_name().size()));
- if (message_name == NULL) {
- return NULL;
- }
- PyDict_SetItem(extensions_by_name.get(), message_name.get(),
- extension_handle);
- }
-
Py_RETURN_NONE;
}
@@ -2057,7 +2141,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
static PyObject* GetExtensionDict(CMessage* self, void *closure);
static PyObject* ListFields(CMessage* self) {
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
// Normally, the list will be exactly the size of the fields.
@@ -2087,8 +2171,8 @@ static PyObject* ListFields(CMessage* self) {
// is no message class and we cannot retrieve the value.
// TODO(amauryfa): consider building the class on the fly!
if (fields[i]->message_type() != NULL &&
- cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self),
+ message_factory::GetMessageClass(
+ GetFactoryForMessage(self),
fields[i]->message_type()) == NULL) {
PyErr_Clear();
continue;
@@ -2121,7 +2205,8 @@ static PyObject* ListFields(CMessage* self) {
return NULL;
}
- PyObject* field_value = GetAttr(self, py_field_name.get());
+ PyObject* field_value =
+ GetAttr(reinterpret_cast<PyObject*>(self), py_field_name.get());
if (field_value == NULL) {
PyErr_SetObject(PyExc_ValueError, py_field_name.get());
return NULL;
@@ -2132,13 +2217,23 @@ static PyObject* ListFields(CMessage* self) {
PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
++actual_size;
}
- Py_SIZE(all_fields.get()) = actual_size;
+ if (static_cast<size_t>(actual_size) != fields.size() &&
+ (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
+ 0)) {
+ return NULL;
+ }
return all_fields.release();
}
+static PyObject* DiscardUnknownFields(CMessage* self) {
+ AssureWritable(self);
+ self->message->DiscardUnknownFields();
+ Py_RETURN_NONE;
+}
+
PyObject* FindInitializationErrors(CMessage* self) {
Message* message = self->message;
- vector<string> errors;
+ std::vector<string> errors;
message->FindInitializationErrors(&errors);
PyObject* error_list = PyList_New(errors.size());
@@ -2235,32 +2330,16 @@ PyObject* InternalGetScalar(const Message* message,
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
- string value = reflection->GetString(*message, field_descriptor);
+ string scratch;
+ const string& value =
+ reflection->GetStringReference(*message, field_descriptor, &scratch);
result = ToStringObject(field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- !message->GetReflection()->HasField(*message, field_descriptor)) {
- // Look for the value in the unknown fields.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() ==
- field_descriptor->number() &&
- unknown_field_set.field(i).type() ==
- google::protobuf::UnknownField::TYPE_VARINT) {
- result = PyInt_FromLong(unknown_field_set.field(i).varint());
- break;
- }
- }
- }
-
- if (result == NULL) {
- const EnumValueDescriptor* enum_value =
- message->GetReflection()->GetEnum(*message, field_descriptor);
- result = PyInt_FromLong(enum_value->number());
- }
+ const EnumValueDescriptor* enum_value =
+ message->GetReflection()->GetEnum(*message, field_descriptor);
+ result = PyInt_FromLong(enum_value->number());
break;
}
default:
@@ -2275,18 +2354,19 @@ PyObject* InternalGetScalar(const Message* message,
PyObject* InternalGetSubMessage(
CMessage* self, const FieldDescriptor* field_descriptor) {
const Reflection* reflection = self->message->GetReflection();
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ PyMessageFactory* factory = GetFactoryForMessage(self);
const Message& sub_message = reflection->GetMessage(
- *self->message, field_descriptor, pool->message_factory);
+ *self->message, field_descriptor, factory->message_factory);
- PyObject *message_class = cdescriptor_pool::GetMessageClass(
- pool, field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
+ factory, field_descriptor->message_type());
+ ScopedPyObjectPtr message_class_handler(
+ reinterpret_cast<PyObject*>(message_class));
if (message_class == NULL) {
return NULL;
}
- CMessage* cmsg = cmessage::NewEmptyMessage(message_class,
- sub_message.GetDescriptor());
+ CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
if (cmsg == NULL) {
return NULL;
}
@@ -2471,7 +2551,10 @@ PyObject* Reduce(CMessage* self) {
if (state == NULL) {
return NULL;
}
- ScopedPyObjectPtr serialized(SerializePartialToString(self));
+ string contents;
+ self->message->SerializePartialToString(&contents);
+ ScopedPyObjectPtr serialized(
+ PyBytes_FromStringAndSize(contents.c_str(), contents.size()));
if (serialized == NULL) {
return NULL;
}
@@ -2531,11 +2614,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) {
return NULL;
}
+static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByName(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
+static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByNumber(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
static PyGetSetDef Getters[] = {
{"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
{NULL}
};
+
static PyMethodDef Methods[] = {
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -2555,6 +2651,8 @@ static PyMethodDef Methods[] = {
"Clears a message field." },
{ "CopyFrom", (PyCFunction)CopyFrom, METH_O,
"Copies a protocol message into the current message." },
+ { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
+ "Discards the unknown fields." },
{ "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
METH_NOARGS,
"Finds unset required fields." },
@@ -2577,9 +2675,10 @@ static PyMethodDef Methods[] = {
{ "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
"Registers an extension with the current message." },
{ "SerializePartialToString", (PyCFunction)SerializePartialToString,
- METH_NOARGS,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, even if it isn't initialized." },
- { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS,
+ { "SerializeToString", (PyCFunction)SerializeToString,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, only for initialized messages." },
{ "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
"Sets the has bit of the given field in its parent message." },
@@ -2605,7 +2704,8 @@ static bool SetCompositeField(
return PyDict_SetItem(self->composite_fields, name, value) == 0;
}
-PyObject* GetAttr(CMessage* self, PyObject* name) {
+PyObject* GetAttr(PyObject* pself, PyObject* name) {
+ CMessage* self = reinterpret_cast<CMessage*>(pself);
PyObject* value = self->composite_fields ?
PyDict_GetItem(self->composite_fields, name) : NULL;
if (value != NULL) {
@@ -2624,8 +2724,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
const Descriptor* entry_type = field_descriptor->message_type();
const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject* value_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), value_type->message_type());
+ CMessageClass* value_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), value_type->message_type());
if (value_class == NULL) {
return NULL;
}
@@ -2647,8 +2747,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyObject* py_container = NULL;
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject *message_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2683,7 +2783,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
return InternalGetScalar(self->message, field_descriptor);
}
-int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
+int SetAttr(PyObject* pself, PyObject* name, PyObject* value) {
+ CMessage* self = reinterpret_cast<CMessage*>(pself);
if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) {
PyErr_SetString(PyExc_TypeError, "Can't set composite field");
return -1;
@@ -2711,7 +2812,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed "
- "(no field \"%s\"in protocol message object).",
+ "(no field \"%s\" in protocol message object).",
PyString_AsString(name));
return -1;
}
@@ -2719,7 +2820,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
} // namespace cmessage
PyTypeObject CMessage_Type = {
- PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0)
+ PyVarObject_HEAD_INIT(&CMessageClass_Type, 0)
FULL_MODULE_NAME ".CMessage", // tp_name
sizeof(CMessage), // tp_basicsize
0, // tp_itemsize
@@ -2728,22 +2829,22 @@ PyTypeObject CMessage_Type = {
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ (reprfunc)cmessage::ToStr, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
0, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
0, // tp_call
(reprfunc)cmessage::ToStr, // tp_str
- (getattrofunc)cmessage::GetAttr, // tp_getattro
- (setattrofunc)cmessage::SetAttr, // tp_setattro
+ cmessage::GetAttr, // tp_getattro
+ cmessage::SetAttr, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
"A ProtocolMessage", // tp_doc
0, // tp_traverse
0, // tp_clear
(richcmpfunc)cmessage::RichCompare, // tp_richcompare
- 0, // tp_weaklistoffset
+ offsetof(CMessage, weakreflist), // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
cmessage::Methods, // tp_methods
@@ -2765,17 +2866,38 @@ const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
+ const Message* message = PyMessage_GetMessagePointer(msg);
+ if (message == NULL) {
+ PyErr_Clear();
+ return NULL;
+ }
+ return message;
+}
+
+static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+ Message* message = PyMessage_GetMutableMessagePointer(msg);
+ if (message == NULL) {
+ PyErr_Clear();
+ return NULL;
+ }
+ return message;
+}
+
+const Message* PyMessage_GetMessagePointer(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
return cmsg->message;
}
-static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
+
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) ||
(cmsg->extensions != NULL &&
@@ -2784,36 +2906,20 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
// the underlying C++ message back to the CMessage (e.g. removed repeated
// composite containers). We only allow direct mutation of the underlying
// C++ message if there is no child data in the CMessage.
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot reliably get a mutable pointer "
+ "to a message with extra references");
return NULL;
}
cmessage::AssureWritable(cmsg);
return cmsg->message;
}
-static const char module_docstring[] =
-"python-proto2 is a module that can be used to enhance proto2 Python API\n"
-"performance.\n"
-"\n"
-"It provides access to the protocol buffers C++ reflection API that\n"
-"implements the basic protocol buffer functions.";
-
void InitGlobals() {
// TODO(gps): Check all return values in this function for NULL and propagate
// the error (MemoryError) on up to result in an import failure. These should
// also be freed and reset to NULL during finalization.
- kPythonZero = PyInt_FromLong(0);
- kint32min_py = PyInt_FromLong(kint32min);
- kint32max_py = PyInt_FromLong(kint32max);
- kuint32max_py = PyLong_FromLongLong(kuint32max);
- kint64min_py = PyLong_FromLongLong(kint64min);
- kint64max_py = PyLong_FromLongLong(kint64max);
- kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
-
kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
- k_cdescriptor = PyString_FromString("_cdescriptor");
- kfull_name = PyString_FromString("full_name");
- k_extensions_by_name = PyString_FromString("_extensions_by_name");
- k_extensions_by_number = PyString_FromString("_extensions_by_number");
PyObject *dummy_obj = PySet_New(NULL);
kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
@@ -2831,15 +2937,20 @@ bool InitProto2MessageModule(PyObject *m) {
return false;
}
+ // Initialize types and globals in message_factory.cc
+ if (!InitMessageFactory()) {
+ return false;
+ }
+
// Initialize constants defined in this file.
InitGlobals();
- PyMessageMeta_Type.tp_base = &PyType_Type;
- if (PyType_Ready(&PyMessageMeta_Type) < 0) {
+ CMessageClass_Type.tp_base = &PyType_Type;
+ if (PyType_Ready(&CMessageClass_Type) < 0) {
return false;
}
PyModule_AddObject(m, "MessageMeta",
- reinterpret_cast<PyObject*>(&PyMessageMeta_Type));
+ reinterpret_cast<PyObject*>(&CMessageClass_Type));
if (PyType_Ready(&CMessage_Type) < 0) {
return false;
@@ -2848,25 +2959,6 @@ bool InitProto2MessageModule(PyObject *m) {
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
// it here as well to document that subclasses need to set it.
PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
- // Subclasses with message extensions will override _extensions_by_name and
- // _extensions_by_number with fresh mutable dictionaries in AddDescriptors.
- // All other classes can share this same immutable mapping.
- ScopedPyObjectPtr empty_dict(PyDict_New());
- if (empty_dict == NULL) {
- return false;
- }
- ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get()));
- if (immutable_dict == NULL) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_name, immutable_dict.get()) < 0) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_number, immutable_dict.get()) < 0) {
- return false;
- }
PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));
@@ -2912,69 +3004,15 @@ bool InitProto2MessageModule(PyObject *m) {
}
// Initialize Map container types.
- {
- // ScalarMapContainer_Type derives from our MutableMapping type.
- ScopedPyObjectPtr containers(PyImport_ImportModule(
- "google.protobuf.internal.containers"));
- if (containers == NULL) {
- return false;
- }
-
- ScopedPyObjectPtr mutable_mapping(
- PyObject_GetAttrString(containers.get(), "MutableMapping"));
- if (mutable_mapping == NULL) {
- return false;
- }
-
- if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
- return false;
- }
-
- Py_INCREF(mutable_mapping.get());
-#if PY_MAJOR_VERSION >= 3
- PyObject* bases = PyTuple_New(1);
- PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
-
- ScalarMapContainer_Type =
- PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type);
-#else
- ScalarMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&ScalarMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "ScalarMapContainer",
- reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
-#endif
-
- if (PyType_Ready(&MapIterator_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MapIterator",
- reinterpret_cast<PyObject*>(&MapIterator_Type));
-
-
-#if PY_MAJOR_VERSION >= 3
- MessageMapContainer_Type =
- PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type);
-#else
- Py_INCREF(mutable_mapping.get());
- MessageMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&MessageMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MessageMapContainer",
- reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
-#endif
+ if (!InitMapContainers()) {
+ return false;
}
+ PyModule_AddObject(m, "ScalarMapContainer",
+ reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
+ PyModule_AddObject(m, "MessageMapContainer",
+ reinterpret_cast<PyObject*>(MessageMapContainer_Type));
+ PyModule_AddObject(m, "MapIterator",
+ reinterpret_cast<PyObject*>(&MapIterator_Type));
if (PyType_Ready(&ExtensionDict_Type) < 0) {
return false;
@@ -3009,6 +3047,10 @@ bool InitProto2MessageModule(PyObject *m) {
&PyFileDescriptor_Type));
PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
&PyOneofDescriptor_Type));
+ PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
+ &PyServiceDescriptor_Type));
+ PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
+ &PyMethodDescriptor_Type));
PyObject* enum_type_wrapper = PyImport_ImportModule(
"google.protobuf.internal.enum_type_wrapper");
@@ -3045,47 +3087,4 @@ bool InitProto2MessageModule(PyObject *m) {
} // namespace python
} // namespace protobuf
-
-
-#if PY_MAJOR_VERSION >= 3
-static struct PyModuleDef _module = {
- PyModuleDef_HEAD_INIT,
- "_message",
- google::protobuf::python::module_docstring,
- -1,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL
-};
-#define INITFUNC PyInit__message
-#define INITFUNC_ERRORVAL NULL
-#else // Python 2
-#define INITFUNC init_message
-#define INITFUNC_ERRORVAL
-#endif
-
-extern "C" {
- PyMODINIT_FUNC INITFUNC(void) {
- PyObject* m;
-#if PY_MAJOR_VERSION >= 3
- m = PyModule_Create(&_module);
-#else
- m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring);
-#endif
- if (m == NULL) {
- return INITFUNC_ERRORVAL;
- }
-
- if (!google::protobuf::python::InitProto2MessageModule(m)) {
- Py_DECREF(m);
- return INITFUNC_ERRORVAL;
- }
-
-#if PY_MAJOR_VERSION >= 3
- return m;
-#endif
- }
-}
} // namespace google