diff options
Diffstat (limited to 'python/google/protobuf/pyext/extension_dict.cc')
-rw-r--r-- | python/google/protobuf/pyext/extension_dict.cc | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index b361b342..555bd293 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -94,13 +94,13 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor == NULL) { return NULL; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->parent->message, descriptor); + return cmessage::InternalGetScalar(self->message, descriptor); } PyObject* value = PyDict_GetItem(self->values, key); @@ -109,6 +109,14 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { return value; } + if (self->parent == NULL) { + // We are in "detached" state. Don't allow further modifications. + // TODO(amauryfa): Support adding non-scalars to a detached extension dict. + // This probably requires to store the type of the main message. + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = cmessage::InternalGetSubMessage( @@ -154,7 +162,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { if (descriptor == NULL) { return -1; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return -1; } @@ -164,9 +172,11 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { "type"); return -1; } - cmessage::AssureWritable(self->parent); - if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { - return -1; + if (self->parent) { + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } } // TODO(tibell): We shouldn't write scalars to the cache. PyDict_SetItem(self->values, key, value); @@ -180,15 +190,17 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { return NULL; } PyObject* value = PyDict_GetItem(self->values, extension); - if (value != NULL) { - if (ReleaseExtension(self, value, descriptor) < 0) { + if (self->parent) { + if (value != NULL) { + if (ReleaseExtension(self, value, descriptor) < 0) { + return NULL; + } + } + if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( + self->parent, descriptor)) == NULL) { return NULL; } } - if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( - self->parent, descriptor)) == NULL) { - return NULL; - } if (PyDict_DelItem(self->values, extension) < 0) { PyErr_Clear(); } @@ -201,8 +213,15 @@ PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { if (descriptor == NULL) { return NULL; } - PyObject* result = cmessage::HasFieldByDescriptor(self->parent, descriptor); - return result; + if (self->parent) { + return cmessage::HasFieldByDescriptor(self->parent, descriptor); + } else { + int exists = PyDict_Contains(self->values, extension); + if (exists < 0) { + return NULL; + } + return PyBool_FromLong(exists); + } } PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { |