aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/google/protobuf/extension_set_heavy.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r--src/google/protobuf/extension_set_heavy.cc257
1 files changed, 237 insertions, 20 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc
index 8555f6f8..426ec5f4 100644
--- a/src/google/protobuf/extension_set_heavy.cc
+++ b/src/google/protobuf/extension_set_heavy.cc
@@ -40,6 +40,7 @@
#include <google/protobuf/message.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/wire_format.h>
+#include <google/protobuf/wire_format_lite_inl.h>
namespace google {
namespace protobuf {
@@ -58,12 +59,26 @@ void ExtensionSet::AppendToList(const Descriptor* containing_type,
}
if (has) {
- output->push_back(
- pool->FindExtensionByNumber(containing_type, iter->first));
+ // TODO(kenton): Looking up each field by number is somewhat unfortunate.
+ // Is there a better way? The problem is that descriptors are lazily-
+ // initialized, so they might not even be constructed until
+ // AppendToList() is called.
+
+ if (iter->second.descriptor == NULL) {
+ output->push_back(pool->FindExtensionByNumber(
+ containing_type, iter->first));
+ } else {
+ output->push_back(iter->second.descriptor);
+ }
}
}
}
+inline FieldDescriptor::Type real_type(FieldType type) {
+ GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE);
+ return static_cast<FieldDescriptor::Type>(type);
+}
+
inline FieldDescriptor::CppType cpp_type(FieldType type) {
return FieldDescriptor::TypeToCppType(
static_cast<FieldDescriptor::Type>(type));
@@ -88,16 +103,16 @@ const MessageLite& ExtensionSet::GetMessage(int number,
}
}
-MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
- const Descriptor* message_type,
+MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
MessageFactory* factory) {
Extension* extension;
- if (MaybeNewExtension(number, &extension)) {
- extension->type = type;
+ if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
+ extension->type = descriptor->type();
GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
extension->is_repeated = false;
extension->is_packed = false;
- const MessageLite* prototype = factory->GetPrototype(message_type);
+ const MessageLite* prototype =
+ factory->GetPrototype(descriptor->message_type());
GOOGLE_CHECK(prototype != NULL);
extension->message_value = prototype->New();
} else {
@@ -107,12 +122,11 @@ MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
return extension->message_value;
}
-MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
- const Descriptor* message_type,
+MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
MessageFactory* factory) {
Extension* extension;
- if (MaybeNewExtension(number, &extension)) {
- extension->type = type;
+ if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) {
+ extension->type = descriptor->type();
GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE);
extension->is_repeated = true;
extension->repeated_message_value =
@@ -128,7 +142,7 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
if (result == NULL) {
const MessageLite* prototype;
if (extension->repeated_message_value->size() == 0) {
- prototype = factory->GetPrototype(message_type);
+ prototype = factory->GetPrototype(descriptor->message_type());
GOOGLE_CHECK(prototype != NULL);
} else {
prototype = &extension->repeated_message_value->Get(0);
@@ -139,18 +153,64 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
return result;
}
+static bool ValidateEnumUsingDescriptor(const void* arg, int number) {
+ return reinterpret_cast<const EnumDescriptor*>(arg)
+ ->FindValueByNumber(number) != NULL;
+}
+
+bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) {
+ const FieldDescriptor* extension =
+ pool_->FindExtensionByNumber(containing_type_, number);
+ if (extension == NULL) {
+ return false;
+ } else {
+ output->type = extension->type();
+ output->is_repeated = extension->is_repeated();
+ output->is_packed = extension->options().packed();
+ output->descriptor = extension;
+ if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ output->message_prototype =
+ factory_->GetPrototype(extension->message_type());
+ GOOGLE_CHECK(output->message_prototype != NULL)
+ << "Extension factory's GetPrototype() returned NULL for extension: "
+ << extension->full_name();
+ } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ output->enum_is_valid = ValidateEnumUsingDescriptor;
+ output->enum_is_valid_arg = extension->enum_type();
+ }
+
+ return true;
+ }
+}
+
bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
- const MessageLite* containing_type,
+ const Message* containing_type,
UnknownFieldSet* unknown_fields) {
UnknownFieldSetFieldSkipper skipper(unknown_fields);
- return ParseField(tag, input, containing_type, &skipper);
+ if (input->GetExtensionPool() == NULL) {
+ GeneratedExtensionFinder finder(containing_type);
+ return ParseField(tag, input, &finder, &skipper);
+ } else {
+ DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
+ input->GetExtensionFactory(),
+ containing_type->GetDescriptor());
+ return ParseField(tag, input, &finder, &skipper);
+ }
}
bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
- const MessageLite* containing_type,
+ const Message* containing_type,
UnknownFieldSet* unknown_fields) {
UnknownFieldSetFieldSkipper skipper(unknown_fields);
- return ParseMessageSet(input, containing_type, &skipper);
+ if (input->GetExtensionPool() == NULL) {
+ GeneratedExtensionFinder finder(containing_type);
+ return ParseMessageSet(input, &finder, &skipper);
+ } else {
+ DescriptorPoolExtensionFinder finder(input->GetExtensionPool(),
+ input->GetExtensionFactory(),
+ containing_type->GetDescriptor());
+ return ParseMessageSet(input, &finder, &skipper);
+ }
}
int ExtensionSet::SpaceUsedExcludingSelf() const {
@@ -175,7 +235,7 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
if (is_repeated) {
switch (cpp_type(type)) {
#define HANDLE_TYPE(UPPERCASE, LOWERCASE) \
- case WireFormatLite::CPPTYPE_##UPPERCASE: \
+ case FieldDescriptor::CPPTYPE_##UPPERCASE: \
total_size += sizeof(*repeated_##LOWERCASE##_value) + \
repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\
break
@@ -189,8 +249,9 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
HANDLE_TYPE( BOOL, bool);
HANDLE_TYPE( ENUM, enum);
HANDLE_TYPE( STRING, string);
+#undef HANDLE_TYPE
- case WireFormatLite::CPPTYPE_MESSAGE:
+ case FieldDescriptor::CPPTYPE_MESSAGE:
// repeated_message_value is actually a RepeatedPtrField<MessageLite>,
// but MessageLite has no SpaceUsed(), so we must directly call
// RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type
@@ -201,11 +262,11 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
}
} else {
switch (cpp_type(type)) {
- case WireFormatLite::CPPTYPE_STRING:
+ case FieldDescriptor::CPPTYPE_STRING:
total_size += sizeof(*string_value) +
StringSpaceUsedExcludingSelf(*string_value);
break;
- case WireFormatLite::CPPTYPE_MESSAGE:
+ case FieldDescriptor::CPPTYPE_MESSAGE:
total_size += down_cast<Message*>(message_value)->SpaceUsed();
break;
default:
@@ -216,6 +277,162 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
return total_size;
}
+// The Serialize*ToArray methods are only needed in the heavy library, as
+// the lite library only generates SerializeWithCachedSizes.
+uint8* ExtensionSet::SerializeWithCachedSizesToArray(
+ int start_field_number, int end_field_number,
+ uint8* target) const {
+ map<int, Extension>::const_iterator iter;
+ for (iter = extensions_.lower_bound(start_field_number);
+ iter != extensions_.end() && iter->first < end_field_number;
+ ++iter) {
+ target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first,
+ target);
+ }
+ return target;
+}
+
+uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray(
+ uint8* target) const {
+ map<int, Extension>::const_iterator iter;
+ for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) {
+ target = iter->second.SerializeMessageSetItemWithCachedSizesToArray(
+ iter->first, target);
+ }
+ return target;
+}
+
+uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
+ int number, uint8* target) const {
+ if (is_repeated) {
+ if (is_packed) {
+ if (cached_size == 0) return target;
+
+ target = WireFormatLite::WriteTagToArray(number,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
+ target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target);
+
+ switch (real_type(type)) {
+#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \
+ case FieldDescriptor::TYPE_##UPPERCASE: \
+ for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \
+ target = WireFormatLite::Write##CAMELCASE##NoTagToArray( \
+ repeated_##LOWERCASE##_value->Get(i), target); \
+ } \
+ break
+
+ HANDLE_TYPE( INT32, Int32, int32);
+ HANDLE_TYPE( INT64, Int64, int64);
+ HANDLE_TYPE( UINT32, UInt32, uint32);
+ HANDLE_TYPE( UINT64, UInt64, uint64);
+ HANDLE_TYPE( SINT32, SInt32, int32);
+ HANDLE_TYPE( SINT64, SInt64, int64);
+ HANDLE_TYPE( FIXED32, Fixed32, uint32);
+ HANDLE_TYPE( FIXED64, Fixed64, uint64);
+ HANDLE_TYPE(SFIXED32, SFixed32, int32);
+ HANDLE_TYPE(SFIXED64, SFixed64, int64);
+ HANDLE_TYPE( FLOAT, Float, float);
+ HANDLE_TYPE( DOUBLE, Double, double);
+ HANDLE_TYPE( BOOL, Bool, bool);
+ HANDLE_TYPE( ENUM, Enum, enum);
+#undef HANDLE_TYPE
+
+ case WireFormatLite::TYPE_STRING:
+ case WireFormatLite::TYPE_BYTES:
+ case WireFormatLite::TYPE_GROUP:
+ case WireFormatLite::TYPE_MESSAGE:
+ GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
+ break;
+ }
+ } else {
+ switch (real_type(type)) {
+#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \
+ case FieldDescriptor::TYPE_##UPPERCASE: \
+ for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \
+ target = WireFormatLite::Write##CAMELCASE##ToArray(number, \
+ repeated_##LOWERCASE##_value->Get(i), target); \
+ } \
+ break
+
+ HANDLE_TYPE( INT32, Int32, int32);
+ HANDLE_TYPE( INT64, Int64, int64);
+ HANDLE_TYPE( UINT32, UInt32, uint32);
+ HANDLE_TYPE( UINT64, UInt64, uint64);
+ HANDLE_TYPE( SINT32, SInt32, int32);
+ HANDLE_TYPE( SINT64, SInt64, int64);
+ HANDLE_TYPE( FIXED32, Fixed32, uint32);
+ HANDLE_TYPE( FIXED64, Fixed64, uint64);
+ HANDLE_TYPE(SFIXED32, SFixed32, int32);
+ HANDLE_TYPE(SFIXED64, SFixed64, int64);
+ HANDLE_TYPE( FLOAT, Float, float);
+ HANDLE_TYPE( DOUBLE, Double, double);
+ HANDLE_TYPE( BOOL, Bool, bool);
+ HANDLE_TYPE( STRING, String, string);
+ HANDLE_TYPE( BYTES, Bytes, string);
+ HANDLE_TYPE( ENUM, Enum, enum);
+ HANDLE_TYPE( GROUP, Group, message);
+ HANDLE_TYPE( MESSAGE, Message, message);
+#undef HANDLE_TYPE
+ }
+ }
+ } else if (!is_cleared) {
+ switch (real_type(type)) {
+#define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE) \
+ case FieldDescriptor::TYPE_##UPPERCASE: \
+ target = WireFormatLite::Write##CAMELCASE##ToArray( \
+ number, VALUE, target); \
+ break
+
+ HANDLE_TYPE( INT32, Int32, int32_value);
+ HANDLE_TYPE( INT64, Int64, int64_value);
+ HANDLE_TYPE( UINT32, UInt32, uint32_value);
+ HANDLE_TYPE( UINT64, UInt64, uint64_value);
+ HANDLE_TYPE( SINT32, SInt32, int32_value);
+ HANDLE_TYPE( SINT64, SInt64, int64_value);
+ HANDLE_TYPE( FIXED32, Fixed32, uint32_value);
+ HANDLE_TYPE( FIXED64, Fixed64, uint64_value);
+ HANDLE_TYPE(SFIXED32, SFixed32, int32_value);
+ HANDLE_TYPE(SFIXED64, SFixed64, int64_value);
+ HANDLE_TYPE( FLOAT, Float, float_value);
+ HANDLE_TYPE( DOUBLE, Double, double_value);
+ HANDLE_TYPE( BOOL, Bool, bool_value);
+ HANDLE_TYPE( STRING, String, *string_value);
+ HANDLE_TYPE( BYTES, Bytes, *string_value);
+ HANDLE_TYPE( ENUM, Enum, enum_value);
+ HANDLE_TYPE( GROUP, Group, *message_value);
+ HANDLE_TYPE( MESSAGE, Message, *message_value);
+#undef HANDLE_TYPE
+ }
+ }
+ return target;
+}
+
+uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
+ int number,
+ uint8* target) const {
+ if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
+ // Not a valid MessageSet extension, but serialize it the normal way.
+ GOOGLE_LOG(WARNING) << "Invalid message set extension.";
+ return SerializeFieldWithCachedSizesToArray(number, target);
+ }
+
+ if (is_cleared) return target;
+
+ // Start group.
+ target = io::CodedOutputStream::WriteTagToArray(
+ WireFormatLite::kMessageSetItemStartTag, target);
+ // Write type ID.
+ target = WireFormatLite::WriteUInt32ToArray(
+ WireFormatLite::kMessageSetTypeIdNumber, number, target);
+ // Write message.
+ target = WireFormatLite::WriteMessageToArray(
+ WireFormatLite::kMessageSetMessageNumber, *message_value, target);
+ // End group.
+ target = io::CodedOutputStream::WriteTagToArray(
+ WireFormatLite::kMessageSetItemEndTag, target);
+ return target;
+}
+
} // namespace internal
} // namespace protobuf
} // namespace google